Graduation_Project/WZM/model/GAT.py

138 lines
5.6 KiB
Python
Raw Normal View History

2024-06-24 19:41:48 +08:00
import math
import torch
from torch import nn
import torch.nn.functional as F
import copy
def l2norm(X, dim, eps=1e-8):
    """L2-normalize ``X`` along ``dim``.

    Every vector taken along ``dim`` is divided by its Euclidean norm plus
    ``eps``.

    Args:
        X: Floating-point tensor to normalize.
        dim: Dimension (or tuple of dimensions) over which the norm is taken.
        eps: Small constant added to the norm so that all-zero vectors do not
            cause a division by zero.

    Returns:
        A tensor with the same shape as ``X``.
    """
    # Use the built-in vector norm instead of pow -> sum -> sqrt: the explicit
    # sqrt has an unbounded derivative at 0, so an all-zero vector produced
    # NaN gradients during back-propagation.  The built-in norm returns a
    # zero (sub)gradient at that point while computing the same forward value.
    norm = torch.linalg.vector_norm(X, ord=2, dim=dim, keepdim=True) + eps
    return X / norm
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a set of nodes.

    The constructor reads ``hidden_size``, ``num_attention_heads`` and
    ``attention_probs_dropout_prob`` from ``config``.  No attention mask is
    applied: every node attends to every node of the same sample.
    """

    def __init__(self, config):
        super().__init__()
        heads = config.num_attention_heads
        head_dim = int(config.hidden_size / heads)

        self.num_attention_heads = heads
        self.attention_head_size = head_dim
        self.all_head_size = heads * head_dim

        # Separate projections for queries, keys and values.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis forward.

        A 3-D input ``(batch, nodes, all_head_size)`` becomes
        ``(batch, heads, nodes, head_size)``.
        """
        split_shape = x.shape[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(split_shape).permute(0, 2, 1, 3)

    def forward(self, input_graph):
        """Return the attended node features.

        Args:
            input_graph: Tensor whose last dimension is ``hidden_size``; the
                head split/merge below requires it to be 3-D.

        Returns:
            Tensor with the leading dimensions of ``input_graph`` and a last
            dimension of ``all_head_size``.
        """
        queries = self.transpose_for_scores(self.query(input_graph))
        keys = self.transpose_for_scores(self.key(input_graph))
        values = self.transpose_for_scores(self.value(input_graph))

        # Raw query/key similarity, scaled by sqrt(head size).
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)

        # Normalise to probabilities, then drop whole attention links
        # (as in the original Transformer).
        probs = self.dropout(scores.softmax(dim=-1))

        # Weighted sum of values, then merge the heads back together.
        context = torch.matmul(probs, values).permute(0, 2, 1, 3).contiguous()
        merged_shape = context.shape[:-2] + (self.all_head_size,)
        return context.view(merged_shape)
class GATLayer(nn.Module):
    """One attention block: self-attention followed by a two-layer MLP.

    Each of the two sub-blocks ends with dropout, a residual connection and
    a ``BatchNorm1d`` over the feature dimension.  All linear layers map
    ``hidden_size`` to ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        drop_p = config.hidden_dropout_prob

        # Attention sub-block.
        self.mha = MultiHeadAttention(config)
        self.fc_in = nn.Linear(width, width)
        self.bn_in = nn.BatchNorm1d(width)
        self.dropout_in = nn.Dropout(drop_p)

        # Feed-forward sub-block.  The activation is a GELU; the attribute
        # keeps its historical name ``relu``.
        self.fc_int = nn.Linear(width, width)
        self.relu = nn.GELU()
        self.fc_out = nn.Linear(width, width)
        self.bn_out = nn.BatchNorm1d(width)
        self.dropout_out = nn.Dropout(drop_p)

    @staticmethod
    def _add_and_norm(norm, update, residual):
        """Add the residual, then batch-normalise over the feature axis.

        ``BatchNorm1d`` expects features on dim 1, so the summed tensor is
        permuted to ``(batch, features, nodes)`` and back again.
        """
        summed = (update + residual).permute(0, 2, 1)
        return norm(summed).permute(0, 2, 1)

    def forward(self, input_graph):
        """Transform ``input_graph`` of shape ``(batch, nodes, hidden_size)``.

        Returns a tensor of the same shape.
        """
        # Attention -> projection -> dropout -> residual + norm.
        attended = self.dropout_in(self.fc_in(self.mha(input_graph)))
        attended = self._add_and_norm(self.bn_in, attended, input_graph)

        # Linear -> GELU -> linear -> dropout -> residual + norm.
        mlp_out = self.fc_out(self.relu(self.fc_int(attended)))
        mlp_out = self.dropout_out(mlp_out)
        return self._add_and_norm(self.bn_out, mlp_out, attended)
class GATopt(object):
    """Hyper-parameter container consumed by the attention layers.

    Args:
        hidden_size: Feature width used by every linear layer.
        num_layers: Number of stacked attention layers.
        num_attention_heads: Number of attention heads (default 1).
        hidden_dropout_prob: Dropout probability after the projection and
            feed-forward layers (default 0.2).
        attention_probs_dropout_prob: Dropout probability applied to the
            attention probabilities (default 0.2).
    """

    def __init__(
        self,
        hidden_size,
        num_layers,
        *,
        num_attention_heads=1,
        hidden_dropout_prob=0.2,
        attention_probs_dropout_prob=0.2,
    ):
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Previously hard-coded; now overridable while keeping the same defaults.
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
class GAT(nn.Module):
    """Attention stack applied to a 4-D feature map.

    Every spatial position of a ``(B, C, H, W)`` input is treated as a node
    with ``C`` features.  The encoder output is blended with the input using
    a single learnable scalar and passed through a GELU.
    """

    def __init__(self, config_gat):
        super().__init__()
        # Build one layer and deep-copy it so that all layers start from
        # identical parameter values.
        prototype = GATLayer(config_gat)
        self.encoder = nn.ModuleList(
            [copy.deepcopy(prototype) for _ in range(config_gat.num_layers)]
        )
        # Learnable mixing coefficient between encoder output and input.
        self.weight1 = nn.Parameter(torch.randn(1))
        # GELU activation; the attribute keeps its historical name ``relu``.
        self.relu = nn.GELU()

    def forward(self, input_graph):
        """Return a tensor with the same ``(B, C, H, W)`` shape as the input."""
        batch, channels, height, width = input_graph.shape

        # (B, C, H, W) -> (B, H*W, C): one node per spatial position.
        nodes = input_graph.reshape(batch, channels, -1).permute(0, 2, 1).contiguous()

        for block in self.encoder:
            nodes = block(nodes)

        # (B, H*W, C) -> (B, C, H, W)
        feature_map = nodes.permute(0, 2, 1).reshape(batch, channels, height, width).contiguous()

        # Blend encoder output with the input, then activate.
        alpha = self.weight1
        return self.relu(alpha * feature_map + (1 - alpha) * input_graph)
class GAT_T(nn.Module):
    """Attention stack for inputs that are already node sequences.

    The input is passed through the layers one after another and the result
    is returned directly, with no reshaping or blending with the input.
    """

    def __init__(self, config_gat):
        super().__init__()
        # Build one layer and deep-copy it so that all layers start from
        # identical parameter values.
        prototype = GATLayer(config_gat)
        self.encoder = nn.ModuleList(
            [copy.deepcopy(prototype) for _ in range(config_gat.num_layers)]
        )

    def forward(self, input_graph):
        """Run ``input_graph`` (B, seq_len, D) through every layer in order."""
        nodes = input_graph
        for block in self.encoder:
            nodes = block(nodes)
        return nodes  # B, seq_len, D