# coding=utf-8
import math

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def positional_encoding_1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length * d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) *
                         -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    return pe


class GPO(nn.Module):
    """Generalized pooling operator: learns, from positional encodings alone,
    one weight per (sorted) position and pools a variable-length feature set
    into a single vector as a weighted sum of the sorted features."""

    def __init__(self, d_pe, d_hidden):
        super(GPO, self).__init__()
        self.d_pe = d_pe
        self.d_hidden = d_hidden

        self.pe_database = {}  # cache of positional encodings, keyed by length
        self.gru = nn.GRU(self.d_pe, d_hidden, 1, batch_first=True, bidirectional=True)
        self.linear = nn.Linear(self.d_hidden, 1, bias=False)

        # Optionally freeze the pooling parameters:
        # for p in self.parameters():
        #     p.requires_grad = False

    def compute_pool_weights(self, lengths, features):
        """Compute a softmax weight for every valid position of every sample,
        using only positional encodings (not the features themselves)."""
        max_len = int(lengths.max())
        pe_max_len = self.get_pe(max_len)
        pes = pe_max_len.unsqueeze(0).repeat(lengths.size(0), 1, 1).to(lengths.device)

        # mask[b, k, 0] is True for the first lengths[b] positions of sample b.
        mask = torch.arange(max_len).expand(lengths.size(0), max_len).to(lengths.device)
        mask = (mask < lengths.long().unsqueeze(1)).unsqueeze(-1)
        pes = pes.masked_fill(mask == 0, 0)

        # Encode the positional encodings with a BiGRU, then average the forward
        # and backward halves so the output dimension matches d_hidden.
        self.gru.flatten_parameters()
        packed = pack_padded_sequence(pes, lengths.cpu(), batch_first=True, enforce_sorted=False)
        out, _ = self.gru(packed)
        out_emb, out_len = pad_packed_sequence(out, batch_first=True)
        out_emb = (out_emb[:, :, :out_emb.size(2) // 2] + out_emb[:, :, out_emb.size(2) // 2:]) / 2

        # Score each position, mask out padding, and softmax with temperature 0.1.
        scores = self.linear(out_emb)
        scores[torch.where(mask == 0)] = -10000
        weights = torch.softmax(scores / 0.1, 1)
        return weights, mask

    def forward(self, features, lengths):
        """
        :param features: features with shape B x K x D
        :param lengths: tensor of shape (B,) giving the number of valid positions per sample
        :return: pooled features with shape B x D, and the pooling weights with shape B x K x 1
        """
        pool_weights, mask = self.compute_pool_weights(lengths, features)

        # Sort each feature dimension in descending order over the valid positions;
        # padded entries are pushed to the end via the -10000 fill, then zeroed.
        features = features[:, :int(lengths.max()), :]
        sorted_features = features.masked_fill(mask == 0, -10000)
        sorted_features = sorted_features.sort(dim=1, descending=True)[0]
        sorted_features = sorted_features.masked_fill(mask == 0, 0)

        # Pool: weighted sum of the sorted features with the learned position weights.
        pooled_features = (sorted_features * pool_weights).sum(1)
        return pooled_features, pool_weights

    def get_pe(self, length):
        """
        :param length: the length of the sequence
        :return: the positional encoding of the given length (cached)
        """
        length = int(length)
        if length in self.pe_database:
            return self.pe_database[length]
        else:
            pe = positional_encoding_1d(self.d_pe, length)
            self.pe_database[length] = pe
            return pe
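

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): pools a batch of
# variable-length feature sequences with GPO. The dimensions, batch size, and
# lengths below are illustrative assumptions, not values prescribed above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    d_pe, d_hidden, d_feat = 32, 32, 1024  # assumed sizes for the demo
    gpo = GPO(d_pe, d_hidden)

    # Batch of 3 feature sequences padded to the max length (5); lengths give
    # the number of valid positions in each sample.
    features = torch.randn(3, 5, d_feat)
    lengths = torch.tensor([5, 3, 2])

    pooled, weights = gpo(features, lengths)
    print(pooled.shape)   # torch.Size([3, 1024])
    print(weights.shape)  # torch.Size([3, 5, 1])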