# Graduation_Project/LYZ/test.py
#
# 223 lines
# 9.5 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters!
#
# This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Just a quick script for printing / smoke-testing the networks.
import torch
import torch.nn as nn
from torch.nn.functional import normalize
class OneHotEncodingLayer(nn.Module):
    """Turn integer byte values into one-hot float vectors.

    Args:
        sz: Number of classes of the one-hot encoding (default 256).
    """

    def __init__(self, sz=256):
        super().__init__()
        # Width of the new trailing one-hot dimension.
        self.size = sz

    def forward(self, x):
        """Return ``x`` one-hot encoded along a new last axis, as float32.

        ``x`` must be an integer (int64) tensor with values in
        ``[0, self.size)``; the result has shape ``x.shape + (self.size,)``.
        """
        encoded = torch.nn.functional.one_hot(x, num_classes=self.size)
        return encoded.float()
class ByteBlock(nn.Module):
    """Stack of Conv1d -> Tanh (-> MaxPool1d) stages followed by a 128-d projection.

    Args:
        in_channels: Channel count of the input to the first convolution.
        nb_filter: Output channels of each convolution stage.
        filter_length: Kernel size of each stage, indexed in step with ``nb_filter``.
        subsample: Stride of each stage.
        pool_length: Max-pool window after each stage; a falsy entry (0/None)
            skips pooling for that stage.

    The input is ``[N, in_channels, L]`` (the layout ``nn.Conv1d`` expects);
    the output is ``[N, 128]`` after global max pooling, a linear layer and ReLU.
    """

    def __init__(self, in_channels, nb_filter=(64, 100), filter_length=(3, 3),
                 subsample=(2, 1), pool_length=(2, 2)):
        super().__init__()
        stages = []
        channels = in_channels
        for idx, out_channels in enumerate(nb_filter):
            conv = nn.Conv1d(channels, out_channels,
                             kernel_size=filter_length[idx],
                             padding=0, stride=subsample[idx])
            stages += [conv, nn.Tanh()]
            window = pool_length[idx]
            if window:
                stages.append(nn.MaxPool1d(window))
            # The next stage consumes this stage's output channels.
            channels = out_channels
        self.block = nn.Sequential(*stages)
        # Reduce the remaining length axis to a single position (max over it).
        self.global_pool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(nb_filter[-1], 128)

    def forward(self, x):
        """Map ``[N, in_channels, L]`` to a non-negative ``[N, 128]`` feature."""
        features = self.block(x)
        pooled = self.global_pool(features).squeeze(dim=2)
        return torch.nn.functional.relu(self.fc(pooled))
class FEMnet(nn.Module):
    """Feature extraction module (FEM) for byte-level flow data.

    Each packet (a row of ``packet_len`` byte values) is one-hot encoded and
    passed through two parallel ``ByteBlock`` CNN branches. Their 128-d outputs
    are concatenated into a 256-d packet embedding. A bidirectional GRU runs
    over the sequence of packet embeddings of each flow. A dot-product
    attention layer then pools the GRU outputs into one flow representation of
    size ``rep_dim`` (= ``2 * gru_units``).

    Args:
        flow_len: Number of packets per flow.
        packet_len: Number of bytes per packet (used as the Conv1d channel count).
        gru_units: Hidden size of each GRU direction.
    """

    def __init__(self, flow_len, packet_len, gru_units):
        super(FEMnet, self).__init__()
        self.packet_len = packet_len
        self.flow_len = flow_len
        # Kept for backward compatibility only: forward() takes the batch size
        # from its input instead of assuming this value.
        self.batch_size = 10
        self.gru_hidden_size = gru_units
        # forward() returns the attention context over the *bidirectional* GRU
        # outputs, which has 2 * gru_units features. Downstream heads size
        # their Linear layers from rep_dim, so it must match that width.
        self.rep_dim = gru_units * 2
        self.embedding = OneHotEncodingLayer(sz=256)  # one-hot encode each byte
        self.block2 = ByteBlock(self.packet_len, (128, 256), (5, 5), (1, 1), (2, 2))
        self.block3 = ByteBlock(self.packet_len, (192, 320), (7, 5), (1, 1), (2, 2))
        # NOTE: with num_layers=1, nn.GRU does not apply `dropout` (and warns);
        # it only acts between stacked layers.
        self.lstm_layer = nn.GRU(256, self.gru_hidden_size, dropout=0.1, bidirectional=True)

    def forward(self, x):
        """Encode a batch of flows into fixed-size feature vectors.

        Args:
            x: Integer (int64) tensor ``[batch, n_packets, packet_len]`` of byte
                values in ``[0, 256)``.

        Returns:
            Tensor ``[batch, 2 * gru_units]``: attention-pooled flow features.
        """
        batch, n_packets, _ = x.shape
        embeddings = self.embedding(x)  # [B, F, P, 256]
        # The CNN branches (conv / pool / linear) treat every sample
        # independently. So all packets of all flows can go through them in
        # one batched call. This also avoids a fixed-size CPU buffer, which
        # broke for other batch sizes and devices.
        packets = embeddings.flatten(0, 1)  # [B*F, P, 256]
        encoder = torch.cat([self.block2(packets), self.block3(packets)], dim=1)  # [B*F, 256]
        encoder = encoder.reshape(batch, n_packets, encoder.shape[-1])  # [B, F, 256]
        # nn.GRU (batch_first=False) expects [seq_len, batch, features].
        gru_out, final_hidden_state = self.lstm_layer(encoder.permute(1, 0, 2))
        gru_out = gru_out.permute(1, 0, 2)  # [B, F, 2H]
        attn_output, _ = self.attention_net(gru_out, final_hidden_state)
        return attn_output

    def attention_net(self, lstm_output, final_state):
        """Dot-product attention over GRU outputs, queried by the final hidden state.

        Args:
            lstm_output: ``[batch, n_step, 2 * gru_hidden_size]`` bidirectional
                GRU outputs (forward and backward features concatenated).
            final_state: ``[num_layers * 2, batch, gru_hidden_size]`` final
                hidden state ``h_n`` from ``nn.GRU(bidirectional=True)``.

        Returns:
            ``(context, weights)``. ``context`` has shape
            ``[batch, 2 * gru_hidden_size]`` and ``weights`` has shape
            ``[batch, n_step]`` (softmax over time steps).
        """
        # h_n is laid out [layer*direction, batch, hidden]. A plain
        # .view(batch, 2H) would pair hidden states of *different* samples.
        # So pair each sample's last-layer forward (-2) and backward (-1)
        # states explicitly, in the [forward, backward] order the GRU uses for
        # its outputs.
        hidden = torch.cat([final_state[-2], final_state[-1]], dim=1).unsqueeze(2)  # [B, 2H, 1]
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)  # [B, n_step]
        soft_attn_weights = torch.nn.functional.softmax(attn_weights, 1)
        # [B, 2H, n_step] x [B, n_step, 1] -> [B, 2H]: weighted sum over time.
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights
class MyModel(nn.Module):
    """Contrastive-clustering model built on top of a feature extractor.

    Two projection heads share the extractor output ``h``:

    * an instance head producing L2-normalised ``feature_dim`` embeddings;
    * a cluster head producing a softmax distribution over ``class_n`` clusters.

    Args:
        binet: Feature-extractor module. It must expose ``rep_dim``, the size
            of the last dimension of its output.
        feature_dim: Output size of the instance projection head.
        class_n: Number of clusters.
    """

    def __init__(self, binet, feature_dim, class_n):
        super(MyModel, self).__init__()
        self.binet = binet
        self.feature_dim = feature_dim
        self.cluster_num = class_n
        rep_dim = self.binet.rep_dim
        # Instance-level projection head.
        self.instance_projector = nn.Sequential(
            nn.Linear(rep_dim, rep_dim),
            nn.ReLU(),
            nn.Linear(rep_dim, self.feature_dim),
        )
        # Cluster-level projection head: soft cluster assignment per sample.
        self.cluster_projector = nn.Sequential(
            nn.Linear(rep_dim, rep_dim),
            nn.ReLU(),
            nn.Linear(rep_dim, self.cluster_num),
            nn.Softmax(dim=1)
        )

    def forward(self, x_i, x_j):
        """Training pass on two views of the same data.

        Per the original note, ``x_i`` and ``x_j`` are embedding matrices
        obtained from two different preprocessings of the same samples.

        Returns:
            ``(z_i, z_j, c_i, c_j)``: L2-normalised instance embeddings
            (``[N, feature_dim]``) and cluster probabilities
            (``[N, class_n]``) for each view.
        """
        # The extractor is registered as ``binet``. The previous
        # ``self.resnet`` was never defined and raised AttributeError.
        h_i = self.binet(x_i)
        h_j = self.binet(x_j)
        z_i = normalize(self.instance_projector(h_i), dim=1)
        z_j = normalize(self.instance_projector(h_j), dim=1)
        c_i = self.cluster_projector(h_i)
        c_j = self.cluster_projector(h_j)
        return z_i, z_j, c_i, c_j

    def forward_cluster(self, x):
        """Inference pass: return the hard cluster label of each sample, ``[N]``."""
        h = self.binet(x)
        c = self.cluster_projector(h)  # cluster distribution
        return torch.argmax(c, dim=1)  # most probable cluster per sample
class MyModel1(nn.Module):
    """Flow model: byte CNN encoder + BiGRU + attention + linear output layer.

    Each packet (``packet_len`` byte values) is one-hot encoded and passed
    through two parallel ``ByteBlock`` branches. Their 128-d outputs are
    concatenated into a 256-d packet embedding. A bidirectional GRU runs over
    the packet sequence of each flow, and attention pools its outputs. A
    linear layer then maps the pooled ``2 * gru_units`` vector to
    ``num_outputs`` raw scores (no softmax applied).

    Args:
        flow_len: Number of packets per flow.
        packet_len: Number of bytes per packet (used as the Conv1d channel count).
        gru_units: Hidden size of each GRU direction.
        num_outputs: Size of the output layer. The default 5 is the previously
            hard-coded value; the original note describes it as the number of
            clusters.
    """

    def __init__(self, flow_len, packet_len, gru_units, num_outputs=5):
        super(MyModel1, self).__init__()
        self.packet_len = packet_len
        self.flow_len = flow_len
        # Kept for backward compatibility only: forward() takes the batch size
        # from its input instead of assuming this value.
        self.batch_size = 10
        self.gru_hidden_size = gru_units
        self.embedding = OneHotEncodingLayer(sz=256)  # one-hot encode each byte
        self.block2 = ByteBlock(self.packet_len, (128, 256), (5, 5), (1, 1), (2, 2))
        self.block3 = ByteBlock(self.packet_len, (192, 320), (7, 5), (1, 1), (2, 2))
        # NOTE: with num_layers=1, nn.GRU does not apply `dropout` (and warns);
        # it only acts between stacked layers.
        self.lstm_layer = nn.GRU(256, self.gru_hidden_size, dropout=0.1, bidirectional=True)
        self.dense_layer = nn.Linear(self.gru_hidden_size * 2, num_outputs)

    def forward(self, x):
        """Score a batch of flows.

        Args:
            x: Integer (int64) tensor ``[batch, n_packets, packet_len]`` of byte
                values in ``[0, 256)``.

        Returns:
            Tensor ``[batch, num_outputs]`` of raw (un-normalised) scores.
        """
        batch, n_packets, _ = x.shape
        embeddings = self.embedding(x)  # [B, F, P, 256]
        # The CNN branches treat every sample independently. So all packets of
        # all flows go through them in one batched call, instead of filling a
        # fixed batch-10 CPU buffer flow by flow.
        packets = embeddings.flatten(0, 1)  # [B*F, P, 256]
        encoder = torch.cat([self.block2(packets), self.block3(packets)], dim=1)  # [B*F, 256]
        encoder = encoder.reshape(batch, n_packets, encoder.shape[-1])  # [B, F, 256]
        # nn.GRU (batch_first=False) expects [seq_len, batch, features].
        gru_out, final_hidden_state = self.lstm_layer(encoder.permute(1, 0, 2))
        gru_out = gru_out.permute(1, 0, 2)  # [B, F, 2H]
        attn_output, _ = self.attention_net(gru_out, final_hidden_state)
        return self.dense_layer(attn_output)

    def attention_net(self, lstm_output, final_state):
        """Dot-product attention over GRU outputs, queried by the final hidden state.

        Args:
            lstm_output: ``[batch, n_step, 2 * gru_hidden_size]`` bidirectional
                GRU outputs (forward and backward features concatenated).
            final_state: ``[num_layers * 2, batch, gru_hidden_size]`` final
                hidden state ``h_n`` from ``nn.GRU(bidirectional=True)``.

        Returns:
            ``(context, weights)``. ``context`` has shape
            ``[batch, 2 * gru_hidden_size]`` and ``weights`` has shape
            ``[batch, n_step]`` (softmax over time steps).
        """
        # h_n is laid out [layer*direction, batch, hidden]. A plain
        # .view(batch, 2H) would pair hidden states of *different* samples.
        # So pair each sample's last-layer forward (-2) and backward (-1)
        # states explicitly, in the [forward, backward] order of the outputs.
        hidden = torch.cat([final_state[-2], final_state[-1]], dim=1).unsqueeze(2)  # [B, 2H, 1]
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)  # [B, n_step]
        soft_attn_weights = torch.nn.functional.softmax(attn_weights, 1)
        # [B, 2H, n_step] x [B, n_step, 1] -> [B, 2H]: weighted sum over time.
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights
if __name__ == '__main__':
    # Smoke test: build the classifier and push one random batch through it.
    # Alternative contrastive-clustering setup:
    # fe_model = FEMnet(14, 100, 92)
    # model = MyModel(fe_model, 128, 5)
    # print(model)
    model = MyModel1(14, 100, 92)
    # Shape [batch_size, flow_len, packet_len]. Byte values span 0..255
    # inclusive and torch.randint's upper bound is exclusive, hence 256.
    data = torch.randint(256, size=(10, 14, 100))
    with torch.no_grad():
        out = model(data)
    print(out.shape)
    # To inspect parameters:
    # params = list(model.parameters())
    # print(len(params))
    # print(params[0].size())  # (0): conv1's .weight