Graduation_Project/LYZ/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
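
# Hierarchical attention model for flow classification:
#   1) a bidirectional byte-level GRU + attention encodes each packet's byte tokens
#      into a fixed-size packet embedding;
#   2) a bidirectional packet-level GRU + attention encodes the sequence of packet
#      embeddings into a single flow representation;
#   3) a linear layer maps the flow representation to 11 class logits.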
class PLA_Attention_Model(nn.Module):
    def __init__(self, byte_hidden_size, packet_hidden_size, packet_output_size, max_packet=1500, max_flow=128):
        super().__init__()
        self.byte_emb_size = 50
        self.byte_hidden_size = byte_hidden_size  # 100
        self.packet_emb_size = 50
        self.packet_hidden_size = packet_hidden_size  # 100
        self.packet_output_size = packet_output_size  # 50
        self.max_packet = max_packet
        self.max_flow = max_flow
        self.embedding = nn.Embedding(256 * 256, self.byte_emb_size)
        self.byte_GRU = nn.GRU(self.byte_emb_size, self.byte_hidden_size, bias=False, bidirectional=True)
        self.byte_attn = nn.Linear(self.byte_hidden_size * 2, self.packet_emb_size)
        self.packet_GRU = nn.GRU(self.packet_emb_size, self.packet_hidden_size, bias=False, bidirectional=True)
        self.packet_attn = nn.Linear(self.packet_hidden_size * 2, self.packet_output_size)
        self.classify = nn.Linear(self.packet_output_size, 11)
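
    # Attention over GRU time steps: the concatenated final hidden state serves as the
    # query, and a softmax over its dot products with every step's output gives the
    # weights for the returned context vector.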
    # lstm_output : [batch_size, n_step, hidden_size * num_directions(=2)]
    # final_state : [num_directions(=2), batch_size, hidden_size]
    def attention_net(self, lstm_output, final_state):
        # print("lstm_output: ", lstm_output.shape)
        # print("final_state: ", final_state.shape)
        # hidden : [batch_size, hidden_size * num_directions(=2), 1]
        # concatenate the forward and backward final states per batch element
        # (a plain view() here would mix states across batch elements)
        hidden = torch.cat((final_state[0], final_state[1]), dim=1).unsqueeze(2)
        # attn_weights : [batch_size, n_step] -- dot product of each step with the query, then squeeze dim 2
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)
        soft_attn_weights = F.softmax(attn_weights, dim=1)
        # [batch_size, hidden_size * num_directions(=2), n_step] x [batch_size, n_step, 1]
        # = [batch_size, hidden_size * num_directions(=2), 1]
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights  # context : [batch_size, hidden_size * num_directions(=2)]
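
    # flow : [num_packet, batch_size, packet_len] integer byte tokens (see the test block below)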
    def forward(self, flow):
        num_packet = flow.shape[0]
        batch_size = flow.shape[1]
        embedded_bytes_list = self.embedding(flow)  # [4, 3, 120, 50]
        encoding_bytes_list = torch.zeros((num_packet, batch_size, self.packet_emb_size))  # .cuda()  # [4, 3, 50]
        for i, embedded_bytes in enumerate(embedded_bytes_list):  # [3, 120, 50]
            h0_bytes = torch.randn(2, batch_size, self.byte_hidden_size)  # .cuda()  # [2, 3, 100]
            embedded_bytes = embedded_bytes.transpose(0, 1)  # swap dim 0 and dim 1: [120, 3, 50]
            # print(embedded_bytes.shape, h0_bytes.shape)
            output, final_hidden_state = self.byte_GRU(embedded_bytes, h0_bytes)  # [120, 3, 200], [2, 3, 100]
            output = output.permute(1, 0, 2)  # [3, 120, 200]
            # print(output.shape)
            # print("final: ", final_hidden_state.shape)
            attn_output, attention = self.attention_net(output, final_hidden_state)  # [3, 200]
            # print("attn_output", attn_output.shape)
            encoding_bytes_list[i] = self.byte_attn(attn_output)  # [3, 50]; full tensor [4, 3, 50] after the loop
        # print("2: ", encoding_bytes_list.shape)  # [4, 3, 50]
        h0_packet = torch.randn(2, batch_size, self.packet_hidden_size)  # .cuda()  # [2, 3, 100]
        # print(h0_packet.shape)
        output, final_hidden_state = self.packet_GRU(encoding_bytes_list, h0_packet)
        # print("final_hidden_state: ", final_hidden_state.shape)  # [2, 3, 100]
        output = output.permute(1, 0, 2)  # permute dimensions: [3, 4, 200]
        attn_output, attention = self.attention_net(output, final_hidden_state)  # [3, 200]
        # print("attn_output2: ", attn_output.shape)
        output = self.packet_attn(attn_output)  # [3, 50]
        classify = self.classify(output)  # [3, 11]
        return classify
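
# Quick shape check: push a batch of 3 random flows, each with 4 packets of 120 byte tokens, through the model.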
if __name__ == '__main__':
    batch_size = 3
    model = PLA_Attention_Model(100, 100, 50)  # .cuda()
    print(model)
    data = torch.randint(255, size=(4, batch_size, 120))  # .cuda()  # flow_len, batch_size, packet_len
    print(data.shape)
    res = model(data)
    print(res.shape)
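    # expected result: res.shape == torch.Size([3, 11]) -- one 11-way logit vector per flow in the batch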