278 lines
12 KiB
Python
278 lines
12 KiB
Python
|
from cgi import print_directory
|
|||
|
from models import *
|
|||
|
from utils.utils import *
|
|||
|
import torch
|
|||
|
import numpy as np
|
|||
|
from copy import deepcopy
|
|||
|
from test import evaluate
|
|||
|
from terminaltables import AsciiTable
|
|||
|
import time
|
|||
|
from utils.prune_utils import *
|
|||
|
|
|||
|
class opt:
    """Hard-coded run options for this pruning script (stands in for argparse).

    The values are read as class attributes (``opt.model_def``, ``opt.data_config``,
    ``opt.model``) throughout the script.
    """

    # Network definition to prune (alternative: config/yolov3-ds8-person.cfg).
    model_def = "config/yolov3-ds-person.cfg"

    # Dataset description; the script reads its "valid" and "names" entries.
    data_config = "config/smallperson.data"

    # Trained weights to prune. Other checkpoints that have been used:
    #   checkpoints/yolov3_ckpt.pth
    #   res8:          yolov3_ckpt_99_06081725.pth
    #   2*res8 + res4: yolov3_ckpt_99_05181112.pth
    model = 'checkpoints/yolov3_ckpt_99_05181112.pth'
|
|||
|
|
|||
|
|
|||
|
#%%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network from its cfg file and move it to the selected device.
model = Darknet(opt.model_def).to(device)

# Load the trained weights. map_location remaps the stored tensors onto
# `device`; without it torch.load restores them to the device they were saved
# from, which fails on the CPU fallback above when the checkpoint came from a GPU.
model.load_state_dict(torch.load(opt.model, map_location=device))
|
|||
|
|
|||
|
# Read the dataset description; only the validation list and class names are used.
data_config = parse_data_config(opt.data_config)
valid_path = data_config["valid"]
class_names = load_classes(data_config["names"])


def eval_model(model):
    """Run ``evaluate`` on the validation split with this script's fixed settings.

    Returns whatever ``evaluate`` returns; element 2 is averaged as the mAP
    elsewhere in this script.
    """
    return evaluate(
        model,
        path=valid_path,
        iou_thres=0.5,
        conf_thres=0.01,
        nms_thres=0.1,
        img_size=model.img_size,
        batch_size=8,
    )


def obtain_num_parameters(model):
    """Return the total number of scalar parameters in ``model``."""
    return sum(p.nelement() for p in model.parameters())


# Baseline accuracy and size of the unpruned model, kept for the final comparison.
origin_model_metric = eval_model(model)
origin_nparameters = obtain_num_parameters(model)
|
|||
|
|
|||
|
# Split the layer indices (helper from utils.prune_utils):
#   CBL_idx   - layers that carry a BatchNorm; obtain_filters_mask reads their
#               BN module and builds a channel mask for each of them.
#   Conv_idx  - passed to init_weights_from_loose_model; presumably the plain
#               convolutions without BN (e.g. the detection head) - TODO confirm.
#   prune_idx - layers whose BN channels may actually be pruned.
# NOTE(review): in an earlier run CBL_idx did not contain 14, the ds_conv
# layer - check how parse_module_defs handles ds_conv blocks.
CBL_idx, Conv_idx, prune_idx = parse_module_defs(model.module_defs)

# Gather the BN gammas of all prunable layers and sort them in ascending order.
# A global percentile of this vector is used as the pruning threshold.
bn_weights = gather_bn_weights(model.module_list, prune_idx)
sorted_bn = torch.sort(bn_weights)[0]

# Highest threshold that avoids pruning away all channels of some layer:
# the smallest of the per-layer maxima of |gamma|.
highest_thre = min(
    model.module_list[idx][1].weight.data.abs().max().item() for idx in prune_idx
)

# Prune ratio corresponding to highest_thre: the fraction of gathered gammas
# strictly below it. For a unique match this equals the index of highest_thre
# in sorted_bn (what `(sorted_bn == highest_thre).nonzero().item()` computed),
# but it does not raise when the value occurs several times (tied gammas) or
# not at all.
percent_limit = (sorted_bn < highest_thre).sum().item() / len(bn_weights)

print(f'Threshold should be less than {highest_thre:.4f}.')
print(f'The corresponding prune ratio is {percent_limit:.3f}.')
|
|||
|
|
|||
|
#%%
|
|||
|
def prune_and_eval(model, sorted_bn, percent=.0):
    """Simulate global BN-gamma pruning at ``percent`` and report the resulting mAP.

    The work is done on a deep copy, so ``model`` itself is not modified. For
    every layer in the module-level ``prune_idx``, the BN gamma
    (``module_list[idx][1].weight``) is multiplied by the mask from
    ``obtain_bn_mask``, zeroing the channels selected for pruning. The copy is
    then evaluated with ``eval_model``.

    Args:
        model: network to probe (left unchanged).
        sorted_bn: 1-D tensor of the gathered BN gammas in ascending order
            (as built by gather_bn_weights + torch.sort).
        percent: fraction of channels to prune. Values outside [0, 1) are
            clamped to the first/last element of ``sorted_bn`` instead of
            indexing out of range (percent=1.0) or wrapping around (negative).

    Returns:
        The threshold used, a 0-d tensor taken from ``sorted_bn``.
    """
    model_copy = deepcopy(model)

    # Clamp the index so percent=1.0 selects the largest gamma instead of
    # raising IndexError, and a negative percent does not wrap to the end.
    thre_index = min(max(int(len(sorted_bn) * percent), 0), len(sorted_bn) - 1)
    thre = sorted_bn[thre_index]

    print(f'Channels with Gamma value less than {thre:.4f} are pruned!')

    remain_num = 0
    for idx in prune_idx:
        bn_module = model_copy.module_list[idx][1]
        mask = obtain_bn_mask(bn_module, thre)
        remain_num += int(mask.sum())
        # Multiplying gamma by the mask zeroes the scale of pruned channels;
        # only the weight is masked here, the BN bias (beta) is left as is.
        bn_module.weight.data.mul_(mask)

    # With a very high prune ratio the model can produce no detections, in
    # which case evaluate's sample_metrics list ends up empty.
    mAP = eval_model(model_copy)[2].mean()

    print(f'Number of channels has been reduced from {len(sorted_bn)} to {remain_num}')
    print(f'Prune ratio: {1-remain_num/len(sorted_bn):.3f}')
    print(f'mAP of the pruned model is {mAP:.4f}')

    return thre
|
|||
|
|
|||
|
# Probe the chosen global prune ratio. The returned gamma threshold drives
# the actual channel masks built below.
percent = 0.85
threshold = prune_and_eval(model, sorted_bn, percent=percent)
|
|||
|
|
|||
|
#%%
|
|||
|
def obtain_filters_mask(model, thre, CBL_idx, prune_idx):
    """Build a channel keep-mask for every CBL layer.

    Args:
        model: model exposing ``module_defs`` (dicts with a 'type' key) and
            ``module_list`` (per-layer sub-module sequences).
        thre: gamma threshold passed to ``obtain_bn_mask`` for layers in ``prune_idx``.
        CBL_idx: indices of layers that carry a BatchNorm module.
        prune_idx: indices of the layers that may actually be pruned.

    Returns:
        ``(num_filters, filters_mask)``: for each index in ``CBL_idx`` (same
        order), the number of channels kept and a numpy 0/1 mask over that
        layer's BN channels. Layers not in ``prune_idx`` get an all-ones mask.
        An empty ``CBL_idx`` yields ``([], [])``.

    Raises:
        RuntimeError: if a prunable layer would lose every channel. (This used
            to be a bare ``Exception``; RuntimeError is a subclass, so existing
            ``except Exception`` handlers still catch it.)
    """
    prunable = set(prune_idx)
    pruned = 0
    total = 0
    num_filters = []
    filters_mask = []

    for idx in CBL_idx:
        # A ds_conv block holds two convolutions before its BatchNorm
        # (presumably depthwise + pointwise), so its BN sits at position 2.
        # Other CBL blocks have it at position 1.
        # NOTE(review): the mask therefore has the BN's channel count (e.g.
        # 256 for ds_conv layer 14) - verify this matches both convolutions.
        bn_position = 2 if model.module_defs[idx]['type'] == 'ds_conv' else 1
        bn_module = model.module_list[idx][bn_position]

        if idx in prunable:
            mask = obtain_bn_mask(bn_module, thre).cpu().numpy()
            remain = int(mask.sum())                 # channels kept
            pruned += mask.shape[0] - remain         # channels removed

            if remain == 0:
                print("Channels would be all pruned!")
                raise RuntimeError(f"Channels of layer {idx} would be all pruned!")

            print(f'layer index: {idx:>3d} \t total channel: {mask.shape[0]:>4d} \t '
                  f'remaining channel: {remain:>4d}')
        else:
            # Layer is not prunable: keep every channel.
            mask = np.ones(bn_module.weight.data.shape)
            remain = mask.shape[0]

        total += mask.shape[0]
        num_filters.append(remain)
        filters_mask.append(mask.copy())

    # Guard against an empty CBL_idx, which would otherwise divide by zero.
    prune_ratio = pruned / total if total else 0.0
    print(f'Prune channels: {pruned}\tPrune ratio: {prune_ratio:.3f}')

    return num_filters, filters_mask
|
|||
|
|
|||
|
# Per-layer channel masks derived from the threshold found by prune_and_eval.
num_filters, filters_mask = obtain_filters_mask(model, threshold, CBL_idx, prune_idx)

#%%
# Map each CBL layer index to its channel mask.
CBLidx2mask = dict(zip(CBL_idx, filters_mask))

# "Loose" pruned model built from the masks. It keeps the original layer
# sizes; the compact, actually-smaller network is built further down.
pruned_model = prune_model_keep_size(model, prune_idx, CBL_idx, CBLidx2mask)

# Compare mAP before and after pruning.
pruned_model_metric = eval_model(pruned_model)
before_map = f'{origin_model_metric[2].mean():.6f}'
after_map = f'{pruned_model_metric[2].mean():.6f}'
print("mAP", before_map, after_map)
|
|
|||
|
#%%
# Copy the original layer definitions and set every CBL layer's 'filters'
# entry to the number of channels that survive pruning.
compact_module_defs = deepcopy(model.module_defs)
for idx, num in zip(CBL_idx, num_filters):
    layer_type = compact_module_defs[idx]['type']
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if layer_type not in ('convolutional', 'ds_conv'):
        raise ValueError(
            f"Layer {idx} has type {layer_type!r}; expected 'convolutional' or 'ds_conv'"
        )
    compact_module_defs[idx]['filters'] = str(num)
|
|||
|
|
|||
|
#%%
# compact_model has the real, slimmed-down architecture. pruned_model above
# only zeroes the weights of pruned channels and keeps the original layer sizes.
compact_cfg = [model.hyperparams.copy()] + compact_module_defs
compact_model = Darknet(compact_cfg).to(device)

# Parameter count of the pruned architecture.
compact_nparameters = obtain_num_parameters(compact_model)

# Copy the surviving weights from the loose pruned model into the compact one.
init_weights_from_loose_model(compact_model, pruned_model, CBL_idx, Conv_idx, CBLidx2mask)

#%%
# Dummy input at the model's input resolution, used for timing and for the
# output-consistency check.
_img_side = model.img_size
random_input = torch.rand((1, 3, _img_side, _img_side)).to(device)
|
|||
|
|
|||
|
# Measure the model's inference time.
def obtain_avg_forward_time(input, model, repeat=200):
    """Return the mean wall-clock time of one forward pass, plus the last output.

    Args:
        input: tensor fed unchanged to ``model`` on every pass.
        model: module to time. It is switched to eval mode, and that change
            persists after the call.
        repeat: number of timed forward passes; must be >= 1.

    Returns:
        ``(avg_infer_time, output)``: average seconds per pass and the output
        of the final pass.

    Raises:
        ValueError: if ``repeat`` < 1. Previously this case failed with
            UnboundLocalError because ``output`` was never assigned.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    model.eval()

    # CUDA kernels run asynchronously. Synchronizing before starting and after
    # the loop makes the clock cover the GPU work, not just the kernel launches.
    sync = torch.cuda.synchronize if input.is_cuda else (lambda: None)

    with torch.no_grad():
        sync()
        start = time.perf_counter()
        for _ in range(repeat):
            output = model(input)
        sync()
        avg_infer_time = (time.perf_counter() - start) / repeat

    return avg_infer_time, output
|
|||
|
|
|||
|
# Time the loose pruned model and the compact model on the same input, and
# keep their outputs for comparison.
pruned_forward_time, pruned_output = obtain_avg_forward_time(random_input, pruned_model)
compact_forward_time, compact_output = obtain_avg_forward_time(random_input, compact_model)

# The compact model should reproduce the loose pruned model's output. Count
# the elements that differ by more than 1e-3; any such element points to a
# problem in pruning or weight copying.
mismatch = (pruned_output - compact_output).abs().gt(0.001)
diff = mismatch.sum().item()
if diff:
    print('Something wrong with the pruned model!')
|
|||
|
|
|||
|
#%%
# Evaluate the compact model on the validation set.
compact_model_metric = eval_model(compact_model)

#%%
# Before/after comparison of accuracy, parameter count and inference time.
# The "Before" time is measured on pruned_model, which keeps the original
# layer sizes.
header = ["Metric", "Before", "After"]
rows = [
    ["mAP",
     f'{origin_model_metric[2].mean():.6f}',
     f'{compact_model_metric[2].mean():.6f}'],
    ["Parameters", f"{origin_nparameters}", f"{compact_nparameters}"],
    ["Inference", f'{pruned_forward_time:.4f}', f'{compact_forward_time:.4f}'],
]
metric_table = [header, *rows]
print(AsciiTable(metric_table).table)
|
|||
|
|
|||
|
#%%
def _add_filename_prefix(path, prefix):
    """Return ``path`` with ``prefix`` inserted before its last '/'-separated component.

    ``path.replace('/', '/' + prefix)`` inserted the prefix after EVERY slash,
    and for a bare filename inserted nothing. The second case made the saves
    below overwrite the original cfg/checkpoint.
    """
    head, sep, tail = path.rpartition('/')
    return f"{head}{sep}{prefix}{tail}"


# Write the pruned architecture as a new cfg file next to the original.
# The "ckpt_99_05181112" tag mirrors the checkpoint name in opt.model; keep
# the two in sync when switching checkpoints.
pruned_cfg_name = _add_filename_prefix(opt.model_def, f'ds_1w_prune_{percent}_ckpt_99_05181112')
pruned_cfg_file = write_cfg(pruned_cfg_name, [model.hyperparams.copy()] + compact_module_defs)
print(f'Config file has been saved: {pruned_cfg_file}')

# Save the compact model's weights next to the original checkpoint.
compact_model_name = _add_filename_prefix(opt.model, f'ds_1w_prune_{percent}_')
torch.save(compact_model.state_dict(), compact_model_name)
print(f'Compact model has been saved: {compact_model_name}')
|