import math
import copy

import torch
from torch import nn
import torch.nn.functional as F


def l2norm(X, dim, eps=1e-8):
    """L2-normalize X along the given dimension."""
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    X = torch.div(X, norm)
    return X


class MultiHeadAttention(nn.Module):

    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # [B, seq_len, all_head_size] -> [B, num_heads, seq_len, head_size]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, input_graph):
        nodes_q = self.query(input_graph)
        nodes_k = self.key(input_graph)
        nodes_v = self.value(input_graph)

        nodes_q_t = self.transpose_for_scores(nodes_q)
        nodes_k_t = self.transpose_for_scores(nodes_k)
        nodes_v_t = self.transpose_for_scores(nodes_v)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(nodes_q_t, nodes_k_t.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # No attention mask is applied here: every node attends to every other node.

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        nodes_new = torch.matmul(attention_probs, nodes_v_t)
        # [B, num_heads, seq_len, head_size] -> [B, seq_len, all_head_size]
        nodes_new = nodes_new.permute(0, 2, 1, 3).contiguous()
        new_nodes_shape = nodes_new.size()[:-2] + (self.all_head_size,)
        nodes_new = nodes_new.view(*new_nodes_shape)
        return nodes_new


class GATLayer(nn.Module):

    def __init__(self, config):
        super(GATLayer, self).__init__()
        self.mha = MultiHeadAttention(config)

        self.fc_in = nn.Linear(config.hidden_size, config.hidden_size)
        self.bn_in = nn.BatchNorm1d(config.hidden_size)
        self.dropout_in = nn.Dropout(config.hidden_dropout_prob)

        self.fc_int = nn.Linear(config.hidden_size, config.hidden_size)
        self.relu = nn.GELU()  # GELU activation, kept under the original `relu` attribute name
        self.fc_out = nn.Linear(config.hidden_size, config.hidden_size)
        self.bn_out = nn.BatchNorm1d(config.hidden_size)
        self.dropout_out = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_graph):
        attention_output = self.mha(input_graph)  # multi-head attention
        attention_output = self.fc_in(attention_output)
        attention_output = self.dropout_in(attention_output)
        # Residual connection, then BatchNorm1d over the feature dimension.
        attention_output = self.bn_in((attention_output + input_graph).permute(0, 2, 1)).permute(0, 2, 1)

        intermediate_output = self.fc_int(attention_output)
        # intermediate_output = F.relu(intermediate_output)
        intermediate_output = self.relu(intermediate_output)
        intermediate_output = self.fc_out(intermediate_output)
        intermediate_output = self.dropout_out(intermediate_output)
        graph_output = self.bn_out((intermediate_output + attention_output).permute(0, 2, 1)).permute(0, 2, 1)
        return graph_output


class GATopt(object):

    def __init__(self, hidden_size, num_layers):
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = 1
        self.hidden_dropout_prob = 0.2
        self.attention_probs_dropout_prob = 0.2


class GAT(nn.Module):

    def __init__(self, config_gat):
        super(GAT, self).__init__()
        layer = GATLayer(config_gat)
        self.encoder = nn.ModuleList([copy.deepcopy(layer) for _ in range(config_gat.num_layers)])
        self.weight1 = nn.Parameter(torch.randn(1))
        # self.weight2 = nn.Parameter(torch.randn(1))
        self.relu = nn.GELU()

    def forward(self, input_graph):
        B, C, H, W = input_graph.shape

        # [B, C, H, W] -> [B, H*W, C]
        hidden_states = input_graph.reshape(B, C, -1, 1).squeeze(3).transpose(2, 1).contiguous()

        for layer_module in self.encoder:
            hidden_states = layer_module(hidden_states)

        # [B, H*W, C] -> [B, C, H, W]
        hidden_states = hidden_states.transpose(2, 1).unsqueeze(3).reshape(B, C, H, W).contiguous()

        # Learned blend of the attended features and the original input.
        output_graph = self.relu(self.weight1 * hidden_states + (1 - self.weight1) * input_graph)
        # output_graph = self.relu(self.weight1 * hidden_states + self.weight2 * input_graph)
        # output_graph = self.relu(0.1 * hidden_states + 0.9 * input_graph)
        return output_graph


class GAT_T(nn.Module):

    def __init__(self, config_gat):
        super(GAT_T, self).__init__()
        layer = GATLayer(config_gat)
        self.encoder = nn.ModuleList([copy.deepcopy(layer) for _ in range(config_gat.num_layers)])

    def forward(self, input_graph):
        hidden_states = input_graph
        for layer_module in self.encoder:
            hidden_states = layer_module(hidden_states)
        return hidden_states  # [B, seq_len, D]
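

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. It assumes GAT is
# applied to a 4-D feature map [B, C, H, W] whose channel count equals
# config.hidden_size, and GAT_T to an already-flattened sequence [B, seq_len, D];
# the shapes and sizes below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = GATopt(hidden_size=64, num_layers=2)

    # Self-attention over the H*W spatial positions of a feature map,
    # with the learned residual blend applied at the end.
    gat = GAT(config)
    feature_map = torch.randn(2, 64, 8, 8)  # [B, C, H, W]
    print(gat(feature_map).shape)           # torch.Size([2, 64, 8, 8])

    # Plain stack of GATLayers over a token sequence, no residual blend.
    gat_t = GAT_T(config)
    tokens = torch.randn(2, 10, 64)         # [B, seq_len, D]
    print(gat_t(tokens).shape)              # torch.Size([2, 10, 64])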