# YOLOv3-model-pruning/test_prune.py
from models import *
from utils.utils import *
import torch
import numpy as np
from copy import deepcopy
from test import evaluate
from terminaltables import AsciiTable
import time
from utils.prune_utils import *

class opt():
    model_def = "config/yolov3-ds-person.cfg"  # yolov3-ds8-person.cfg
    data_config = "config/smallperson.data"  # smallperson.data
    model = 'checkpoints/yolov3_ckpt_99_05181112.pth'  # checkpoints/yolov3_ckpt.pth  # res8: yolov3_ckpt_99_06081725.pth  # 2*res8+res4: yolov3_ckpt_99_05181112
#%%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Darknet(opt.model_def).to(device)  # build the model
model.load_state_dict(torch.load(opt.model))  # load the trained weights
# print(model)
# for name in model.state_dict():
#     print(name)
data_config = parse_data_config(opt.data_config)
valid_path = data_config["valid"]
class_names = load_classes(data_config["names"])
# Create the evaluation function with a lambda expression
eval_model = lambda model: evaluate(model, path=valid_path, iou_thres=0.5, conf_thres=0.01,
                                    nms_thres=0.1, img_size=model.img_size, batch_size=8)
# print(eval_model)
obtain_num_parameters = lambda model:sum([param.nelement() for param in model.parameters()])
# Get the baseline evaluation metrics and parameter count
origin_model_metric = eval_model(model)
origin_nparameters = obtain_num_parameters(model)
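# Note: evaluate is assumed to follow the common PyTorch-YOLOv3 convention of
# returning (precision, recall, AP, f1, ap_class), which is why metric[2].mean()
# is used as the mAP later in this script.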
CBL_idx, Conv_idx, prune_idx = parse_module_defs(model.module_defs)
# print(CBL_idx, Conv_idx, prune_idx)
# CBL_idx = [0, 1, 2, 3, 5, 6, 7, 9, 10, 12, 13, 16, 19, 22, 25, 28, 31, 34, 37, 38, 41, 44, 47, 50, 53, 56, 59, 62, 63, 66,
# 69, 72, 75, 76, 77, 78, 79, 80, 84, 87, 88, 89, 90, 91, 92, 96, 99, 100, 101, 102, 103, 104]
# Known issue: CBL_idx does not contain index 14, which is the ds_conv layer
# Conv_idx = [81, 93, 105]
# prune_idx = [0, 2, 6, 9, 75, 76, 77, 78, 79, 80, 87, 88, 89, 90, 91, 92, 99, 100, 101, 102, 103, 104]
#prune_idx =
#[0, 2, 6, 9, 13, 16, 19, 22, 25, 28, 31, 34, 38, 41, 44, 47, 50, 53, 56, 59, 63, 66, 69, 72, 75,
# 76, 77, 78, 79, 80, 87, 88, 89, 90, 91, 92, 99, 100, 101, 102, 103, 104]
bn_weights = gather_bn_weights(model.module_list, prune_idx)
# print("model.module_list[0]:",model.module_list[0],"---",len(model.module_list[0]))
# print('bn_weights = ' ,bn_weights, "---",len(bn_weights))
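# For reference, a minimal sketch of what gather_bn_weights (utils.prune_utils)
# is assumed to do: flatten the absolute BN gamma values of all prunable layers
# into one tensor, so a single global threshold can be picked. The [conv, bn, act]
# module layout is an assumption inferred from the indexing used in this script.
def gather_bn_weights_sketch(module_list, prune_idx):
    sizes = [module_list[idx][1].weight.data.shape[0] for idx in prune_idx]
    weights = torch.zeros(sum(sizes))
    offset = 0
    for idx, size in zip(prune_idx, sizes):
        # copy |gamma| of this layer's BN into the flat tensor
        weights[offset:offset + size] = module_list[idx][1].weight.data.abs().cpu()
        offset += size
    return weights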
sorted_bn = torch.sort(bn_weights)[0]  # sort the BN gammas in ascending order
# print('sorted_bn = ' ,sorted_bn)
# Upper bound on the threshold, to avoid pruning away all channels of some layer
# (the minimum over layers of each BN layer's maximum gamma is that upper bound)
highest_thre = []
for idx in prune_idx:
    highest_thre.append(model.module_list[idx][1].weight.data.abs().max().item())
highest_thre = min(highest_thre)
# Find the percentile of the sorted gammas at which highest_thre sits
percent_limit = (sorted_bn==highest_thre).nonzero().item()/len(bn_weights)
print(f'Threshold should be less than {highest_thre:.4f}.')
print(f'The corresponding prune ratio is {percent_limit:.3f}.')
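# For reference, a minimal sketch of what obtain_bn_mask (utils.prune_utils) is
# assumed to compute: a 0/1 float mask over a BN layer's channels given a gamma
# threshold (name and signature inferred from the calls below).
def obtain_bn_mask_sketch(bn_module, thre):
    thre = torch.as_tensor(thre, device=bn_module.weight.device)  # match devices
    # 1.0 -> channel survives, 0.0 -> channel is pruned
    return bn_module.weight.data.abs().ge(thre).float()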
#%%
def prune_and_eval(model, sorted_bn, percent=.0):
    model_copy = deepcopy(model)
    thre_index = int(len(sorted_bn) * percent)
    # print('thre_index = ', thre_index)  # thre_index = 11369
    thre = sorted_bn[thre_index]
    # print('thre = ', thre)  # thre = tensor(0.0925)
    # print(model_copy)
    print(f'Channels with Gamma value less than {thre:.4f} are pruned!')
    remain_num = 0
    for idx in prune_idx:
        # prune_idx = [0, 2, 6, 9, 75, 76, 77, 78, 79, 80, 87, 88, 89, 90, 91, 92, 99, 100, 101, 102, 103, 104]
        # print("idx:", idx)
        bn_module = model_copy.module_list[idx][1]
        mask = obtain_bn_mask(bn_module, thre)  # build the channel mask
        # print(idx, ': ', len(mask))
        # print(idx, ':', len(mask), ':', mask)
        remain_num += int(mask.sum())
        # Multiplying the BN weights (gamma) by this mask is equivalent to pruning
        bn_module.weight.data.mul_(mask)  # apply the mask to the original weights
    # print("remain_num:", remain_num)
    # print(model_copy)
    # print(bn_module.weight.data)
    # model_copy.module_list[0][1] = BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    # print(model_copy.module_list[0][1].weight.data)
    """
    model_copy.module_list[0][1].weight.data =
    tensor([ 1.6235,  0.3670,  0.6683,  0.2203, -0.2113,  1.6356,  0.0717,  0.5802,
             0.1437, -0.3640,  0.2322,  0.2651,  0.7316,  0.6135,  1.6100,  0.8620,
             0.1987,  0.5357,  0.2006,  0.2127,  0.7190, -1.1396, -0.2585, -0.4673,
             0.0498,  0.5148,  0.7377,  0.3179,  1.2934,  1.1743,  0.2840,  0.2782],
           device='cuda:0')
    """
    # Why sample_metrics = [] after pruning: if the prune ratio is too high,
    # nothing is detected and the list comes back empty
    mAP = eval_model(model_copy)[2].mean()
    # print(mAP)
    print(f'Number of channels has been reduced from {len(sorted_bn)} to {remain_num}')
    print(f'Prune ratio: {1 - remain_num / len(sorted_bn):.3f}')
    print(f'mAP of the pruned model is {mAP:.4f}')
    return thre
# Call the function defined above
percent = 0.85  # 0.85
threshold = prune_and_eval(model, sorted_bn, percent)
# print(threshold)
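# A sweep over candidate prune ratios could reuse the same helper (sketch; each
# call re-runs the full evaluation, so this is slow and left commented out):
# for p in (0.5, 0.7, 0.85):
#     prune_and_eval(model, sorted_bn, p)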
#%%
def obtain_filters_mask(model, thre, CBL_idx, prune_idx):
    pruned = 0
    total = 0
    num_filters = []
    filters_mask = []
    for idx in CBL_idx:
        if model.module_defs[idx]['type'] == 'ds_conv':
            bn_module = model.module_list[idx][2]
        else:
            bn_module = model.module_list[idx][1]
        # print("idx", idx, "--bn_module:", model.module_list[idx][1])
        # Prune this layer only if idx is in the prune-index list
        if idx in prune_idx:
            # prune_idx = [0, 2, 6, 9, 75, 76, 77, 78, 79, 80, 87, 88, 89, 90, 91, 92, 99, 100, 101, 102, 103, 104]
            mask = obtain_bn_mask(bn_module, thre).cpu().numpy()
            # number of channels kept
            remain = int(mask.sum())
            # number of channels pruned
            pruned = pruned + mask.shape[0] - remain
            if remain == 0:
                print("Channels would be all pruned!")
                raise Exception
            print(f'layer index: {idx:>3d} \t total channel: {mask.shape[0]:>4d} \t '
                  f'remaining channel: {remain:>4d}')
        else:
            # layers not being pruned keep all of their channels
            mask = np.ones(bn_module.weight.data.shape)
            remain = mask.shape[0]
        # Suspected trouble spot: ds_conv has two conv layers, yet the layer-14
        # ds_conv mask has length 256 because the mask is computed from the BN
        # layer -- worth double-checking
        total += mask.shape[0]
        num_filters.append(remain)        # filters that survive pruning
        filters_mask.append(mask.copy())  # the pruning mask
    # print("num_filters:", len(num_filters), "CBL_idx:", len(CBL_idx))  # both are 72
    prune_ratio = pruned / total
    print(f'Prune channels: {pruned}\tPrune ratio: {prune_ratio:.3f}')
    return num_filters, filters_mask
# Call the function defined above
num_filters, filters_mask = obtain_filters_mask(model, threshold, CBL_idx, prune_idx)
# print("num_filters : ", num_filters)
# num_filters = [29, 64, 29, 64, 128, 59, 128, 61, 128, 256, 128, 256, 128, 256, 128, 256, 128, 256, 128, 256, 128, 256, 128, 256, 128, 256, 512,
# 256, 512, 256, 512, 256, 512, 256, 512, 256, 512, 256, 512, 256, 512, 256, 512, 1024, 512, 1024, 512, 1024, 512, 1024, 512, 1024,
# 264, 422, 237, 441, 263, 139, 256, 226, 352, 179, 278, 137, 215, 128, 111, 182, 90, 140, 54, 220]
#%%
# Build a dict mapping layer idx -> mask
CBLidx2mask = {idx: mask for idx, mask in zip(CBL_idx, filters_mask)}
# print("CBLidx2mask:", CBLidx2mask)
# Get the pruned model (weights masked, tensor sizes unchanged)
pruned_model = prune_model_keep_size(model, prune_idx, CBL_idx, CBLidx2mask)
# print("pruned_model:",pruned_model)
# Evaluate the pruned (masked) model
pruned_model_metric = eval_model(pruned_model)
print("mAP", f'{origin_model_metric[2].mean():.6f}', f'{pruned_model_metric[2].mean():.6f}')
# pruned_nparameters = obtain_num_parameters(pruned_model)
# print("pruned_nparameters:",pruned_nparameters)
#%%
# Make a copy of the original model's module definitions
compact_module_defs = deepcopy(model.module_defs)
# For every CBL module to prune, set its channel count to the post-pruning value
for idx, num in zip(CBL_idx, num_filters):
    # if compact_module_defs[idx]['type'] == 'ds_conv':
    #     # continue
    #     compact_module_defs[idx]['filters'] = str(num)
    assert compact_module_defs[idx]['type'] == 'convolutional' or compact_module_defs[idx]['type'] == 'ds_conv'
    compact_module_defs[idx]['filters'] = str(num)
    # The network structure was changed, so try updating the ds_conv channel
    # count here first; if that does not work, change it at the mask step above
    # assert compact_module_defs[idx]['type'] == 'ds_conv'
    # compact_module_defs[idx]['filters'] = str(num)
# print("compact_module_defs:",compact_module_defs)
#%%
# compact_model is the real post-pruning network structure (note: the pruned
# model above only zeroed the weights of the channels to prune in the conv/BN/
# activation layers; it did not save a smaller network)
# print("model.hyperparams:",model.hyperparams)
compact_model = Darknet([model.hyperparams.copy()] + compact_module_defs).to(device)
# compact_model_metric = eval_model(compact_model)
# print("compact_model_metric:", "mAP", f'{compact_model_metric[2].mean():.6f}')
# print("model.hyperparams:",model.hyperparams)
# print("compact_model:", compact_model)
# Count the parameters of the compact model
compact_nparameters = obtain_num_parameters(compact_model)
# print("compact_nparameters:", compact_nparameters)  # 18404269 is already the post-pruning count
# Re-copy the weight parameters into the real post-pruning network structure
init_weights_from_loose_model(compact_model, pruned_model, CBL_idx, Conv_idx, CBLidx2mask)
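# For reference, a hypothetical fragment of how init_weights_from_loose_model
# is assumed to select surviving channels from a mask (idx, compact_bn and
# loose_bn are placeholders, not variables defined in this script):
# out_channel_idx = np.argwhere(CBLidx2mask[idx])[:, 0].tolist()
# compact_bn.weight.data = loose_bn.weight.data[out_channel_idx].clone()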
# Compare the per-layer weights of compact_model and pruned_model
# for name, parameters in pruned_model.named_parameters():
#     print(name, '1;', parameters.size())
# for name, parameters in compact_model.named_parameters():
#     print(name, '2;', parameters.size())
# print("pruned_model:",pruned_model.state_dict())
# print("compact_model:",compact_model.state_dict())
#%%
random_input = torch.rand((1, 3, model.img_size, model.img_size)).to(device)
# Measure the model's average forward-pass time
def obtain_avg_forward_time(input, model, repeat=200):
    model.eval()
    start = time.time()
    with torch.no_grad():
        for i in range(repeat):
            output = model(input)
    avg_infer_time = (time.time() - start) / repeat
    return avg_infer_time, output
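# Note: CUDA kernels launch asynchronously, so a stricter benchmark would call
# torch.cuda.synchronize() before each time.time() reading; the simple
# wall-clock average above is enough for a rough before/after comparison.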
# Get the inference time and output of the masked model and of the compact model
pruned_forward_time, pruned_output = obtain_avg_forward_time(random_input, pruned_model)
compact_forward_time, compact_output = obtain_avg_forward_time(random_input, compact_model)
# print("pruned_forward_time:",pruned_forward_time,"---compact_forward_time:", compact_forward_time)
# print("pruned_output:",pruned_output,"---compact_output:", compact_output)
# Compare the inference outputs of the masked model and the compact model;
# a large difference means something went wrong somewhere
# (Comment out the lines below first to check whether the script runs)
diff = (pruned_output-compact_output).abs().gt(0.001).sum().item()
# print("diff:", diff)
if diff > 0:
    print('Something wrong with the pruned model!')
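# The two outputs should agree up to floating-point tolerance: pruned_model
# keeps the original channel count with pruned channels zeroed/compensated,
# while compact_model removes those channels physically.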
#%%
# Test the pruned model on the test set and count its parameters
compact_model_metric = eval_model(compact_model)
#%%
# Compare parameter counts and metric performance before and after pruning
metric_table = [
    ["Metric", "Before", "After"],
    ["mAP", f'{origin_model_metric[2].mean():.6f}', f'{compact_model_metric[2].mean():.6f}'],
    ["Parameters", f"{origin_nparameters}", f"{compact_nparameters}"],
    ["Inference (s)", f'{pruned_forward_time:.4f}', f'{compact_forward_time:.4f}']
]
print(AsciiTable(metric_table).table)
#%%
# Write out the pruned cfg file and save the compact model's weights
pruned_cfg_name = opt.model_def.replace('/', f'/ds_1w_prune_{percent}_ckpt_99_05181112')
pruned_cfg_file = write_cfg(pruned_cfg_name, [model.hyperparams.copy()] + compact_module_defs)
print(f'Config file has been saved: {pruned_cfg_file}')
compact_model_name = opt.model.replace('/', f'/ds_1w_prune_{percent}_')
torch.save(compact_model.state_dict(), compact_model_name)
print(f'Compact model has been saved: {compact_model_name}')