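"""
Hierarchical attention model for flow classification (PLA_Attention_Model).

A flow is a sequence of packets, and each packet a sequence of byte values.
A byte-level bidirectional GRU with attention encodes each packet, a
packet-level bidirectional GRU with attention encodes the flow, and a linear
layer maps the flow encoding to 11 classes. The shape comments below follow
the toy example in __main__ (4 packets, batch of 3, 120 bytes per packet).

Note: this summary is inferred from the code itself; the meaning of "PLA" and
of the 11 target classes is not documented in this file.
"""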
import torch
import torch.nn as nn
import torch.nn.functional as F


class PLA_Attention_Model(nn.Module):
    def __init__(self, byte_hidden_size, packet_hidden_size, packet_output_size, max_packet=1500, max_flow=128):
        super().__init__()
        self.byte_emb_size = 50
        self.byte_hidden_size = byte_hidden_size      # 100
        self.packet_emb_size = 50
        self.packet_hidden_size = packet_hidden_size  # 100
        self.packet_output_size = packet_output_size  # 50
        self.max_packet = max_packet
        self.max_flow = max_flow

        self.embedding = nn.Embedding(256 * 256, self.byte_emb_size)
        self.byte_GRU = nn.GRU(self.byte_emb_size, self.byte_hidden_size, bias=False, bidirectional=True)
        self.byte_attn = nn.Linear(self.byte_hidden_size * 2, self.packet_emb_size)
        self.packet_GRU = nn.GRU(self.packet_emb_size, self.packet_hidden_size, bias=False, bidirectional=True)
        self.packet_attn = nn.Linear(self.packet_hidden_size * 2, self.packet_output_size)
        self.classify = nn.Linear(self.packet_output_size, 11)

    # lstm_output : [batch_size, n_step, self.byte_hidden_size * num_directions(=2)], F matrix
    def attention_net(self, lstm_output, final_state):
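        """
        Attention over the GRU outputs, scored against the (reshaped) final hidden state.

        lstm_output : [batch_size, n_step, hidden_size * 2]
        final_state : [num_directions(=2), batch_size, hidden_size]
        Returns the context vector [batch_size, hidden_size * 2] and the softmax weights.

        Note: the reshape below uses self.byte_hidden_size, so reusing this method for the
        packet-level GRU implicitly assumes packet_hidden_size == byte_hidden_size (both 100 here).
        """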
        # hidden : [batch_size, self.byte_hidden_size * num_directions(=2), 1(=n_layer)]
        hidden = final_state.view(-1, self.byte_hidden_size * 2, 1)  # reshape; -1 is inferred from the other dims
        # attn_weights : [batch_size, n_step]
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)  # dot product of each step with the final state
        soft_attn_weights = F.softmax(attn_weights, 1)
        # [batch_size, self.byte_hidden_size * num_directions(=2), n_step] * [batch_size, n_step, 1]
        # = [batch_size, self.byte_hidden_size * num_directions(=2), 1]
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights  # context : [batch_size, self.byte_hidden_size * num_directions(=2)]

    def forward(self, flow):
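        """
        flow : integer byte tensor of shape [num_packet, batch_size, packet_len],
               e.g. [4, 3, 120] in the __main__ example.
        Returns class logits of shape [batch_size, 11].
        """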
        num_packet = flow.shape[0]
        batch_size = flow.shape[1]
        embedded_bytes_list = self.embedding(flow)  # [num_packet, batch_size, packet_len, byte_emb_size] = [4, 3, 120, 50]
        encoding_bytes_list = torch.zeros((num_packet, batch_size, self.packet_emb_size))  # .cuda()  # [4, 3, 50]
        # Byte-level encoder: bidirectional GRU + attention over each packet's bytes
        for i, embedded_bytes in enumerate(embedded_bytes_list):  # [3, 120, 50]
            h0_bytes = torch.randn(2, batch_size, self.byte_hidden_size)  # .cuda()  # [2, 3, 100]
            embedded_bytes = embedded_bytes.transpose(0, 1)  # swap dim 0 and dim 1: [120, 3, 50]
            output, final_hidden_state = self.byte_GRU(embedded_bytes, h0_bytes)  # [120, 3, 200], [2, 3, 100]
            output = output.permute(1, 0, 2)  # [3, 120, 200]
            attn_output, attention = self.attention_net(output, final_hidden_state)  # [3, 200]
            encoding_bytes_list[i] = self.byte_attn(attn_output)  # project [3, 200] -> [3, 50]: packet i's encoding

        # Packet-level encoder: bidirectional GRU + attention over the packet encodings [4, 3, 50]
        h0_packet = torch.randn(2, batch_size, self.packet_hidden_size)  # .cuda()  # [2, 3, 100]
        output, final_hidden_state = self.packet_GRU(encoding_bytes_list, h0_packet)  # [4, 3, 200], [2, 3, 100]
        output = output.permute(1, 0, 2)  # permute dims: [3, 4, 200]
        attn_output, attention = self.attention_net(output, final_hidden_state)  # [3, 200]
        output = self.packet_attn(attn_output)  # [3, 50]
        classify = self.classify(output)  # [3, 11]

        return classify


if __name__ == '__main__':
    batch_size = 3
    model = PLA_Attention_Model(100, 100, 50)  # .cuda()
    print(model)
    data = torch.randint(255, size=(4, batch_size, 120))  # .cuda()  # flow_len, batch_size, packet_len
    print(data.shape)
    res = model(data)
    print(res.shape)