90 lines
3.0 KiB
Python
90 lines
3.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import os
|
|
import argparse
|
|
import numpy as np
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from torchvision.utils import make_grid
|
|
from torch.optim import lr_scheduler
|
|
|
|
from model import *
|
|
from loss import *
|
|
from dataloader import *
|
|
|
|
|
|
def get_opt():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('-n', '--num-workers', type=int, default=4)
|
|
parser.add_argument('-e', '--epoch', type=int, default=10)
|
|
parser.add_argument('-b', '--batch-size', type=int, default=128)
|
|
parser.add_argument('-t', '--test-batch-size', type=int, default=1)
|
|
parser.add_argument('-d', '--display-step', type=int, default=600)
|
|
opt = parser.parse_args()
|
|
return opt
|
|
|
|
|
|
def train(opt):
|
|
model = PLA_Attention_Model(100, 100, 50).cuda()
|
|
# model.load_state_dict(torch.load('checkpoint.pt'))
|
|
model.train()
|
|
|
|
train_dataset = NetworkFlowDataset('Thursday-WorkingHours.pcap',
|
|
'Thursday-WorkingHours-Morning-WebAttacks.pcap_ISCX.csv')
|
|
train_data_loader = NetworkFlowDataloader(opt, train_dataset)
|
|
|
|
test_dataset = NetworkFlowDataset('Friday-WorkingHours.pcap', 'Friday-WorkingHours-Morning.pcap_ISCX.csv')
|
|
test_data_loader = NetworkFlowDataloader(opt, test_dataset)
|
|
|
|
optim = torch.optim.Adam(model.parameters(), lr=0.001)
|
|
|
|
criterion = Loss()
|
|
|
|
writer = SummaryWriter()
|
|
|
|
for epoch in range(opt.epoch):
|
|
for i in range(len(train_data_loader.data_loader)):
|
|
step = epoch * len(train_data_loader.data_loader) + i + 1
|
|
|
|
# load data
|
|
flow, label = train_data_loader.next_batch()
|
|
flow = flow.cuda()
|
|
label = label.cuda()
|
|
|
|
# train model
|
|
optim.zero_grad() # 梯度清零,不加则会累加梯度
|
|
result = model(flow)
|
|
loss = criterion(result, label)
|
|
loss.backward()
|
|
optim.step()
|
|
|
|
writer.add_scalar('loss', loss, step)
|
|
|
|
writer.close()
|
|
|
|
if step % opt.display_step == 0:
|
|
_, predicted = torch.max(result, 1)
|
|
total = label.size(0)
|
|
correct = (predicted == label).sum().item()
|
|
total_test = 0
|
|
correct_test = 0
|
|
for i in range(len(test_data_loader.data_loader)):
|
|
flow, label = test_data_loader.next_batch()
|
|
flow = flow.cuda()
|
|
label = label.cuda()
|
|
result = model(flow)
|
|
_, predicted = torch.max(result, 1)
|
|
total_test += label.size(0)
|
|
correct_test += (predicted == label).sum().item()
|
|
print(
|
|
'[Epoch {}] Loss : {:.2}, train_acc : {:.2}, test_acc : {:.2}'.format(epoch, loss, correct / total,
|
|
correct_test / total_test))
|
|
|
|
torch.save(model.state_dict(), 'checkpoint.pt')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
opt = get_opt()
|
|
train(opt)
|