# YOLOv3-model-pruning/train.py

from __future__ import division

from models import *
from utils.logger import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
from utils.prune_utils import *
from test import evaluate

# reload is kept for hot-reloading modules while debugging
from importlib import reload
import debug_utils

from terminaltables import AsciiTable

import os
import time
import datetime
import argparse

import torch
from torch.utils.data import DataLoader

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=56, help="size of each image batch")  # default=16
    parser.add_argument("--model_def", type=str, default="config/yolov3-person.cfg",
                        help="path to model definition file")  # default="config/yolov3-hand.cfg"
    parser.add_argument("--data_config", type=str, default="config/person.data",
                        help="path to data config file")  # default="config/oxfordhand.data"
    # parser.add_argument("--pretrained_weights", type=str, default="weights/darknet53.conv.74",
    parser.add_argument("--pretrained_weights", "-pre", type=str, default="weights/yolov3.weights",
                        help="if specified starts from checkpoint model")
    parser.add_argument("--n_cpu", type=int, default=4, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=20, help="interval between saving model weights")  # default=5
    parser.add_argument("--evaluation_interval", type=int, default=1, help="interval evaluations on validation set")
    parser.add_argument("--multiscale_training", action="store_true", help="allow for multi-scale training")
    parser.add_argument("--debug_file", type=str, default="debug", help="enter ipdb if dir exists")
    parser.add_argument("--learning_rate", "-lr", dest="lr", type=float, default=1e-3, help="initial learning rate")
    parser.add_argument("--sparsity-regularization", "-sr", dest="sr", action="store_true",
                        help="train with channel sparsity regularization")
    parser.add_argument("--s", type=float, default=0.01, help="scale sparse rate")
    opt = parser.parse_args()
    print(opt)
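    # Example invocations (paths illustrative): plain training uses the
    # defaults above; sparsity training for later channel pruning adds -sr
    # and a penalty scale, e.g.
    #   python train.py -sr --s 0.01 \
    #       --model_def config/yolov3-person.cfg --data_config config/person.data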
    logger = Logger("logs")

    # set the random seeds for reproducibility
    init_seeds()
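    # init_seeds comes from utils.utils; presumably it seeds Python's random,
    # NumPy and PyTorch (and may set cudnn flags), roughly:
    #     random.seed(0); np.random.seed(0); torch.manual_seed(0)
    # (sketch only -- see utils/utils.py for the actual implementation)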
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    timestamp = datetime.datetime.fromtimestamp(time.time()).strftime('%m%d%H%M')

    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)
    # Get data configuration
    data_config = parse_data_config(opt.data_config)
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])
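    # The .data file presumably follows the usual darknet layout; only the
    # train/valid/names keys are read above. Illustrative contents:
    #   classes=1
    #   train=data/person/train.txt
    #   valid=data/person/valid.txt
    #   names=data/person.names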
    # Initiate model
    model = Darknet(opt.model_def).to(device)
    model.apply(weights_init_normal)
    # print(model)  # uncomment to inspect the model architecture

    # Optionally start from a checkpoint. Loading is currently disabled, which
    # is why the checkpoints saved below are tagged "nopretrain".
    # if opt.pretrained_weights:
    #     if opt.pretrained_weights.endswith(".pth"):
    #         model.load_state_dict(torch.load(opt.pretrained_weights))
    #     else:
    #         model.load_darknet_weights(opt.pretrained_weights)
    _, _, prune_idx = parse_module_defs(model.module_defs)  # layer indices needed for pruning
    # Kept for reference: the prune_idx produced for the stock cfg
    # (indices 84 and 96 are excluded):
    # [0,
    #  2,
    #  6, 9,
    #  13, 16, 19, 22, 25, 28, 31, 34,
    #  38, 41, 44, 47, 50, 53, 56, 59,
    #  63, 66, 69, 72,
    #  75, 76, 77, 78, 79, 80,
    #  87, 88, 89, 90, 91, 92,
    #  99, 100, 101, 102, 103, 104]
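    # parse_module_defs lives in utils/prune_utils.py; judging by the
    # unpacking above it presumably returns three index lists (e.g. CBL_idx,
    # Conv_idx, prune_idx), where prune_idx names the conv+BN blocks whose
    # channels are safe to prune -- layers feeding shortcut (residual)
    # connections are skipped so the element-wise adds keep matching channel
    # counts.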
    # Get dataloader (load the training data)
    dataset = ListDataset(train_path, augment=True, multiscale=opt.multiscale_training)
    dataloader = DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )
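    # collate_fn comes from ListDataset; presumably it stacks the images into
    # one batch tensor, concatenates the variable-length target lists with a
    # batch-index column, and (when multiscale is on) reselects the input
    # resolution every few batches.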
    optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)

    metrics = [
        "grid_size",
        "loss",
        "x", "y", "w", "h",
        "conf",
        "cls", "cls_acc",
        "recall50", "recall75",
        "precision",
        "conf_obj", "conf_noobj",
    ]
    # trackers for checkpoint selection
    tmp_loss = 0
    min_loss = float("inf")  # lowest end-of-epoch loss seen so far
    best_mAP = 0
    for epoch in range(opt.epochs):
        # drop into ipdb when the debug directory exists
        if os.path.exists(opt.debug_file):
            import ipdb
            ipdb.set_trace()

        model.train()
        start_time = time.time()
        sr_flag = get_sr_flag(epoch, opt.sr)  # whether to apply the sparsity penalty this epoch

        # parameter count: constant during plain training, it only shrinks
        # once channels are actually pruned
        n_parameters = sum(p.nelement() for p in model.parameters())
        print(f"Parameters: {n_parameters}")
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            batches_done = len(dataloader) * epoch + batch_i

            imgs = imgs.to(device)
            targets = targets.to(device)

            loss, outputs = model(imgs, targets)

            optimizer.zero_grad()
            loss.backward()
            BNOptimizer.updateBN(sr_flag, model.module_list, opt.s, prune_idx)
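            # Network-slimming sparsity step: when sr_flag is set, updateBN
            # presumably adds the subgradient of an L1 penalty on the BN scale
            # factors (gamma) of the prunable layers before the optimizer
            # step, roughly (sketch; the real code is in utils/prune_utils.py):
            #     for idx in prune_idx:
            #         bn = module_list[idx][1]  # BatchNorm2d in the conv block
            #         bn.weight.grad.data.add_(s * torch.sign(bn.weight.data))
            # This drives unimportant channels' gammas toward zero so they can
            # be pruned later.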
            optimizer.step()
            # ----------------
            #  Log progress
            # ----------------
            log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (epoch, opt.epochs, batch_i, len(dataloader))

            metric_table = [["Metrics", *[f"YOLO Layer {i}" for i in range(len(model.yolo_layers))]]]

            # Log metrics at each YOLO layer
            formats = {m: "%.6f" for m in metrics}
            formats["grid_size"] = "%2d"
            formats["cls_acc"] = "%.2f%%"
            for metric in metrics:
                row_metrics = [formats[metric] % yolo.metrics.get(metric, 0) for yolo in model.yolo_layers]
                metric_table += [[metric, *row_metrics]]
            log_str += AsciiTable(metric_table).table

            log_str += f"\nTotal loss {loss.item()}"
            # remember the most recent batch loss of this epoch
            tmp_loss = loss.item()

            # Determine approximate time left for epoch
            epoch_batches_left = len(dataloader) - (batch_i + 1)
            time_left = datetime.timedelta(seconds=epoch_batches_left * (time.time() - start_time) / (batch_i + 1))
            log_str += f"\n---- ETA {time_left}"

            print(log_str)
            # Tensorboard logging
            tensorboard_log = []
            for i, yolo in enumerate(model.yolo_layers):
                for name, metric in yolo.metrics.items():
                    # write only a subset of the metrics to tensorboard
                    if name not in {"grid_size", "x", "y", "w", "h", "cls_acc"}:
                        tensorboard_log += [(f"{name}_{i+1}", metric)]
            tensorboard_log += [("loss", loss.item())]
            tensorboard_log += [("lr", optimizer.param_groups[0]["lr"])]
            logger.list_of_scalars_summary("train", tensorboard_log, batches_done)
        if epoch % opt.evaluation_interval == 0:
            print("\n---- Evaluating Model ----")

            # Evaluate the model on the validation set
            precision, recall, AP, f1, ap_class = evaluate(
                model,
                path=valid_path,
                iou_thres=0.5,
                conf_thres=0.01,
                nms_thres=0.5,
                img_size=opt.img_size,
                batch_size=8,
            )
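            # Note: conf_thres is deliberately low here, which is standard for
            # mAP computation -- it keeps nearly all detections so the
            # precision/recall curve is sampled over its full range; raising
            # it would truncate the curve and understate AP.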
            evaluation_metrics = [
                ("val_precision", precision.mean()),
                ("val_recall", recall.mean()),
                ("val_mAP", AP.mean()),
                ("val_f1", f1.mean()),
            ]
            logger.list_of_scalars_summary("valid", evaluation_metrics, epoch)
            # Print class APs and mAP
            ap_table = [["Index", "Class name", "AP"]]
            for i, c in enumerate(ap_class):
                ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
            print(AsciiTable(ap_table).table)
            print(f"---- mAP {AP.mean()}")
            # log the distribution of BN scale factors (gamma) to tensorboard
            bn_weights = gather_bn_weights(model.module_list, prune_idx)
            logger.writer.add_histogram("bn_weights/hist", bn_weights.numpy(), epoch, bins="doane")
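            # gather_bn_weights (utils/prune_utils.py) presumably concatenates
            # the BN gammas of all prunable layers into one 1-D tensor, roughly:
            #     torch.cat([module_list[idx][1].weight.data.abs().clone()
            #                for idx in prune_idx])
            # Watching this histogram collapse toward zero is the usual way to
            # verify that sparsity training is taking effect.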
            # save the model with the highest mAP so far
            if best_mAP < AP.mean():
                best_mAP = AP.mean()
                torch.save(model.state_dict(), f"checkpoints/xs_nopretrain_pt_best_mAP_{epoch}_yolov3_ckpt.pth")

            # save the model with the lowest loss so far
            if tmp_loss < min_loss:
                min_loss = tmp_loss
                torch.save(model.state_dict(), f"checkpoints/xs_nopretrain_pt_min_loss_{epoch}_yolov3_ckpt.pth")

        if epoch % opt.checkpoint_interval == 0 or epoch == opt.epochs - 1:
            torch.save(model.state_dict(), f"checkpoints/xs_nopretrain_yolov3_ckpt_{epoch}_{timestamp}.pth")
            print(f"model has been saved: checkpoints/xs_nopretrain_yolov3_ckpt_{epoch}_{timestamp}.pth")