# Graduation_Project/LHL/gpo.py
# coding=utf-8
import torch
import torch.nn as nn
import math
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def positional_encoding_1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length*d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                          -(math.log(10000.0) / d_model)))
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    return pe
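
# Illustrative note (not part of the original file): positional_encoding_1d
# returns a length x d_model matrix, e.g. positional_encoding_1d(6, 4).shape
# is torch.Size([4, 6]), and row 0 is [0, 1, 0, 1, ...] since sin(0)=0, cos(0)=1.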


class GPO(nn.Module):
    def __init__(self, d_pe, d_hidden):
        super(GPO, self).__init__()
        self.d_pe = d_pe
        self.d_hidden = d_hidden
        self.pe_database = {}
        self.gru = nn.GRU(self.d_pe, d_hidden, 1, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(self.d_hidden, 1, bias=False)
        # for p in self.parameters():
        #     p.requires_grad = False

    def compute_pool_weights(self, lengths, features):
        max_len = int(lengths.max())
        pe_max_len = self.get_pe(max_len)
        pes = pe_max_len.unsqueeze(0).repeat(lengths.size(0), 1, 1).to(lengths.device)
        # Boolean mask of valid (non-padded) positions: B x max_len x 1.
        mask = torch.arange(max_len).expand(lengths.size(0), max_len).to(lengths.device)
        mask = (mask < lengths.long().unsqueeze(1)).unsqueeze(-1)
        pes = pes.masked_fill(mask == 0, 0)

        # Run the positional encodings through the bidirectional GRU and
        # average the forward and backward halves of its output.
        self.gru.flatten_parameters()
        packed = pack_padded_sequence(pes, lengths.cpu(), batch_first=True, enforce_sorted=False)
        out, _ = self.gru(packed)
        padded = pad_packed_sequence(out, batch_first=True)
        out_emb, out_len = padded
        out_emb = (out_emb[:, :, :out_emb.size(2) // 2] + out_emb[:, :, out_emb.size(2) // 2:]) / 2

        # Score each position, mask out padding, and softmax with temperature 0.1.
        scores = self.linear(out_emb)
        scores[torch.where(mask == 0)] = -10000
        weights = torch.softmax(scores / 0.1, 1)
        return weights, mask

    def forward(self, features, lengths):
        """
        :param features: features with shape B x K x D
        :param lengths: 1D tensor of length B, the valid length of each sample
        :return: pooled feature with shape B x D, and the pooling weights
        """
        pool_weights, mask = self.compute_pool_weights(lengths, features)

        # Sort each sample's valid features per dimension in descending order,
        # then take the learned weighted sum over positions.
        features = features[:, :int(lengths.max()), :]
        sorted_features = features.masked_fill(mask == 0, -10000)
        sorted_features = sorted_features.sort(dim=1, descending=True)[0]
        sorted_features = sorted_features.masked_fill(mask == 0, 0)
        pooled_features = (sorted_features * pool_weights).sum(1)
        return pooled_features, pool_weights

    def get_pe(self, length):
        """
        :param length: the length of the sequence
        :return: the positional encoding of the given length
        """
        length = int(length)
        if length in self.pe_database:
            return self.pe_database[length]
        else:
            pe = positional_encoding_1d(self.d_pe, length)
            self.pe_database[length] = pe
            return pe
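

# A minimal usage sketch (an illustrative assumption, not part of the original
# file): pool a batch of variable-length feature sequences with GPO. The sizes
# below (d_pe, d_hidden, feature dim, batch, lengths) are made up.
if __name__ == "__main__":
    torch.manual_seed(0)
    pooler = GPO(d_pe=32, d_hidden=32)
    features = torch.randn(4, 10, 256)      # B x K x D
    lengths = torch.tensor([10, 7, 5, 3])   # valid length of each sample
    pooled, weights = pooler(features, lengths)
    print(pooled.shape)   # torch.Size([4, 256])
    print(weights.shape)  # torch.Size([4, 10, 1])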