# coding=utf-8
import math

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def positional_encoding_1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length x d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    # Inverse frequencies 1 / 10000^(2i / d_model), shared by each sin/cos pair.
    div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                          -(math.log(10000.0) / d_model)))
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)

    return pe


class GPO(nn.Module):
    def __init__(self, d_pe, d_hidden):
        super(GPO, self).__init__()
        self.d_pe = d_pe
        self.d_hidden = d_hidden

        # Cache of positional encodings, keyed by sequence length.
        self.pe_database = {}
        # A bidirectional GRU over the positional encodings yields a
        # length-aware hidden state per position; the linear layer maps it
        # to a scalar score used to weight the sorted features.
        self.gru = nn.GRU(self.d_pe, d_hidden, 1, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(self.d_hidden, 1, bias=False)

        # for p in self.parameters():
        #     p.requires_grad = False

    def compute_pool_weights(self, lengths, features):
        max_len = int(lengths.max())
        # Positional encodings up to the longest sequence in the batch,
        # zeroed out at the padded positions.
        pe_max_len = self.get_pe(max_len)
        pes = pe_max_len.unsqueeze(0).repeat(lengths.size(0), 1, 1).to(lengths.device)
        mask = torch.arange(max_len).expand(lengths.size(0), max_len).to(lengths.device)
        mask = (mask < lengths.long().unsqueeze(1)).unsqueeze(-1)
        pes = pes.masked_fill(mask == 0, 0)

        # Score each position with the BiGRU (averaging the two directions),
        # then mask out padding before the softmax.
        self.gru.flatten_parameters()
        packed = pack_padded_sequence(pes, lengths.cpu(), batch_first=True, enforce_sorted=False)
        out, _ = self.gru(packed)
        padded = pad_packed_sequence(out, batch_first=True)
        out_emb, out_len = padded
        out_emb = (out_emb[:, :, :out_emb.size(2) // 2] + out_emb[:, :, out_emb.size(2) // 2:]) / 2
        scores = self.linear(out_emb)
        scores[torch.where(mask == 0)] = -10000

        # Softmax with temperature 0.1 sharpens the per-position weights.
        weights = torch.softmax(scores / 0.1, 1)
        return weights, mask

    def forward(self, features, lengths):
        """
        :param features: features with shape B x K x D
        :param lengths: B x 1, specifies the length of each data sample.
        :return: pooled feature with shape B x D, together with the
                 per-position pooling weights
        """
        pool_weights, mask = self.compute_pool_weights(lengths, features)

        # Sort each feature dimension in descending order over the valid
        # positions, then take a weighted sum of the sorted values.
        features = features[:, :int(lengths.max()), :]
        sorted_features = features.masked_fill(mask == 0, -10000)
        sorted_features = sorted_features.sort(dim=1, descending=True)[0]
        sorted_features = sorted_features.masked_fill(mask == 0, 0)

        pooled_features = (sorted_features * pool_weights).sum(1)
        return pooled_features, pool_weights

    def get_pe(self, length):
        """
        :param length: the length of the sequence
        :return: the positional encoding of the given length
        """
        length = int(length)
        if length in self.pe_database:
            return self.pe_database[length]
        else:
            pe = positional_encoding_1d(self.d_pe, length)
            self.pe_database[length] = pe
            return pe
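

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: pool a batch of
    # variable-length feature sets with GPO. All sizes below (d_pe, d_hidden,
    # feature dim, batch size, lengths) are arbitrary illustration values.
    torch.manual_seed(0)

    d_pe, d_hidden, d_feat = 32, 32, 256
    gpo = GPO(d_pe, d_hidden)

    batch_size, max_len = 4, 10
    features = torch.randn(batch_size, max_len, d_feat)
    lengths = torch.tensor([10, 7, 5, 3])

    pooled, weights = gpo(features, lengths)
    print(pooled.shape)   # expected: torch.Size([4, 256])
    print(weights.shape)  # expected: torch.Size([4, 10, 1])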