import torch
import torch.nn as nn
import torch.nn.functional as F


class PLA_Attention_Model(nn.Module):
    """Hierarchical (byte -> packet -> flow) attention model for traffic classification."""

    def __init__(self, byte_hidden_size, packet_hidden_size, packet_output_size, max_packet=1500, max_flow=128):
        super().__init__()
        self.byte_emb_size = 50
        self.byte_hidden_size = byte_hidden_size      # e.g. 100
        self.packet_emb_size = 50
        self.packet_hidden_size = packet_hidden_size  # e.g. 100
        self.packet_output_size = packet_output_size  # e.g. 50
        self.max_packet = max_packet
        self.max_flow = max_flow
        self.embedding = nn.Embedding(256 * 256, self.byte_emb_size)
        self.byte_GRU = nn.GRU(self.byte_emb_size, self.byte_hidden_size, bias=False, bidirectional=True)
        self.byte_attn = nn.Linear(self.byte_hidden_size * 2, self.packet_emb_size)
        self.packet_GRU = nn.GRU(self.packet_emb_size, self.packet_hidden_size, bias=False, bidirectional=True)
        self.packet_attn = nn.Linear(self.packet_hidden_size * 2, self.packet_output_size)
        self.classify = nn.Linear(self.packet_output_size, 11)

    def attention_net(self, gru_output, final_state):
        """Dot-product attention over the time steps of a bidirectional GRU.

        gru_output  : [batch_size, n_step, hidden_size * 2]
        final_state : [num_directions(=2), batch_size, hidden_size]
        """
        # Concatenate the forward and backward final hidden states into
        # [batch_size, hidden_size * 2, 1]. (A plain .view(-1, hidden * 2, 1)
        # would interleave rows from different batch elements.)
        hidden = torch.cat((final_state[0], final_state[1]), dim=1).unsqueeze(2)
        # attn_weights : [batch_size, n_step]
        attn_weights = torch.bmm(gru_output, hidden).squeeze(2)
        soft_attn_weights = F.softmax(attn_weights, dim=1)
        # [batch_size, hidden*2, n_step] x [batch_size, n_step, 1] -> [batch_size, hidden*2]
        context = torch.bmm(gru_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights  # context : [batch_size, hidden_size * 2]

    def forward(self, flow):
        # flow : [num_packet, batch_size, packet_len] of integer byte tokens
        num_packet, batch_size = flow.shape[0], flow.shape[1]
        device = flow.device
        embedded_bytes_list = self.embedding(flow)  # [num_packet, batch_size, packet_len, byte_emb_size]
        # Per-packet encodings produced by the byte-level GRU + attention.
        encoding_bytes_list = torch.zeros((num_packet, batch_size, self.packet_emb_size), device=device)
        for i, embedded_bytes in enumerate(embedded_bytes_list):      # [batch_size, packet_len, byte_emb_size]
            h0_bytes = torch.randn(2, batch_size, self.byte_hidden_size, device=device)
            embedded_bytes = embedded_bytes.transpose(0, 1)           # [packet_len, batch_size, byte_emb_size]
            output, final_hidden_state = self.byte_GRU(embedded_bytes, h0_bytes)  # [packet_len, batch, 2*hidden], [2, batch, hidden]
            output = output.permute(1, 0, 2)                          # [batch_size, packet_len, 2*hidden]
            attn_output, attention = self.attention_net(output, final_hidden_state)  # [batch_size, 2*hidden]
            encoding_bytes_list[i] = self.byte_attn(attn_output)      # [batch_size, packet_emb_size]
        # Packet-level GRU + attention over the sequence of packet encodings.
        h0_packet = torch.randn(2, batch_size, self.packet_hidden_size, device=device)
        output, final_hidden_state = self.packet_GRU(encoding_bytes_list, h0_packet)  # [num_packet, batch, 2*hidden], [2, batch, hidden]
        output = output.permute(1, 0, 2)                              # [batch_size, num_packet, 2*hidden]
        attn_output, attention = self.attention_net(output, final_hidden_state)      # [batch_size, 2*hidden]
        output = self.packet_attn(attn_output)                        # [batch_size, packet_output_size]
        classify = self.classify(output)                              # [batch_size, 11]
        return classify


if __name__ == '__main__':
    batch_size = 3
    model = PLA_Attention_Model(100, 100, 50)  # .cuda()
    print(model)
    # flow_len=4, batch_size=3, packet_len=120
    data = torch.randint(255, size=(4, batch_size, 120))  # .cuda()
    print(data.shape)  # torch.Size([4, 3, 120])
    res = model(data)
    print(res.shape)   # torch.Size([3, 11])
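
# --- Optional usage sketch (not part of the original file) ---
# A minimal single training step, assuming flows are LongTensor batches of shape
# [flow_len, batch_size, packet_len] and labels are class indices in [0, 11).
# The optimizer and loss choices (Adam, CrossEntropyLoss) are assumptions, not
# something specified by the model above.
def train_step(model, flow, labels, optimizer, criterion):
    """Run one supervised update and return the scalar loss."""
    model.train()
    optimizer.zero_grad()
    logits = model(flow)              # [batch_size, 11]
    loss = criterion(logits, labels)  # labels : [batch_size] int64
    loss.backward()
    optimizer.step()
    return loss.item()

# Example (hypothetical values, reusing `model` and `data` from the demo above):
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   criterion = nn.CrossEntropyLoss()
#   labels = torch.randint(11, size=(batch_size,))
#   loss = train_step(model, data, labels, optimizer, criterion)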