# NOTE: the former `from cgi import print_directory` was dropped: the name was
# never used in this script, and the `cgi` module is deprecated (PEP 594) and
# removed in Python 3.13, where that import fails before any work is done.
# The remaining import order is kept as-is because the star imports can shadow
# each other.
from models import *
from utils.utils import *
import torch
import numpy as np
from copy import deepcopy
from test import evaluate
from terminaltables import AsciiTable
import time
from utils.prune_utils import *


class opt():
    """Script configuration: network cfg, dataset cfg and checkpoint to prune."""
    model_def = "config/yolov3-ds-person.cfg"  # alternative: yolov3-ds8-person.cfg
    data_config = "config/smallperson.data"
    # Other checkpoints used before: checkpoints/yolov3_ckpt.pth,
    # res8 -> yolov3_ckpt_99_06081725.pth, 2*res8+res4 -> yolov3_ckpt_99_05181112.
    model = 'checkpoints/yolov3_ckpt_99_05181112.pth'


#%%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Darknet(opt.model_def).to(device)  # build the network from the cfg
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(opt.model, map_location=device))

data_config = parse_data_config(opt.data_config)
valid_path = data_config["valid"]
class_names = load_classes(data_config["names"])


def eval_model(model):
    """Evaluate `model` on the validation set and return `evaluate`'s result.

    Element [2] of the result is used throughout this script as the per-class
    AP array (its .mean() is reported as mAP) -- presumably
    (precision, recall, AP, f1, ap_class); confirm against test.evaluate.
    """
    return evaluate(
        model,
        path=valid_path,
        iou_thres=0.5,
        conf_thres=0.01,
        nms_thres=0.1,
        img_size=model.img_size,
        batch_size=8,
    )


def obtain_num_parameters(model):
    """Return the total number of scalar elements in all parameters of `model`."""
    return sum(param.nelement() for param in model.parameters())


def _get_bn_module(model, idx):
    """Return the BatchNorm sub-module of layer `idx`.

    ds_conv layers hold their BN at position 2 of the layer's sub-module list;
    every other CBL layer holds it at position 1. This is the same convention
    obtain_filters_mask uses, shared here so every BN lookup agrees.
    """
    if model.module_defs[idx]['type'] == 'ds_conv':
        return model.module_list[idx][2]
    return model.module_list[idx][1]


# Baseline metrics and parameter count of the unpruned model.
origin_model_metric = eval_model(model)
origin_nparameters = obtain_num_parameters(model)

CBL_idx, Conv_idx, prune_idx = parse_module_defs(model.module_defs)
# NOTE(author, translated): at one point CBL_idx did not contain 14, the ds_conv
# layer. Example values seen: Conv_idx = [81, 93, 105],
# prune_idx = [0, 2, 6, 9, 75..80, 87..92, 99..104].

bn_weights = gather_bn_weights(model.module_list, prune_idx)
sorted_bn = torch.sort(bn_weights)[0]  # BN gammas in ascending order

# Highest usable threshold: the smallest per-layer max |gamma|. A threshold at
# or above it would remove every channel of at least one layer.
highest_thre = min(
    _get_bn_module(model, idx).weight.data.abs().max().item() for idx in prune_idx
)
# Prune ratio corresponding to highest_thre: the number of gammas strictly below
# it (= index of its first occurrence in sorted_bn) over the total. Counting
# replaces `(sorted_bn == highest_thre).nonzero().item()`, which raised whenever
# the value occurred more than once because .item() needs exactly one match.
# Assumes gather_bn_weights returns |gamma| like the loop above -- TODO confirm.
percent_limit = (sorted_bn < highest_thre).sum().item() / len(bn_weights)

print(f'Threshold should be less than {highest_thre:.4f}.')
print(f'The corresponding prune ratio is {percent_limit:.3f}.')


#%%
def prune_and_eval(model, sorted_bn, percent=.0):
    """Soft-prune a copy of `model` at a global ratio and report the resulting mAP.

    The threshold is the gamma at position int(len(sorted_bn) * percent). For
    every layer in prune_idx, the mask from obtain_bn_mask is multiplied into the
    BN gammas of a deep copy. This zeroes the pruned channels without changing
    layer widths. The copy is then evaluated and statistics are printed.

    Args:
        model: the trained model; it is not modified.
        sorted_bn: 1-D tensor of BN gammas sorted in ascending order.
        percent: fraction of channels to prune, expected in [0, 1].

    Returns:
        The threshold (a 0-dim tensor) that was applied.
    """
    model_copy = deepcopy(model)
    # Clamp so that percent == 1.0 does not index one past the end.
    thre_index = min(int(len(sorted_bn) * percent), len(sorted_bn) - 1)
    thre = sorted_bn[thre_index]
    print(f'Channels with Gamma value less than {thre:.4f} are pruned!')

    remain_num = 0
    for idx in prune_idx:
        bn_module = _get_bn_module(model_copy, idx)
        mask = obtain_bn_mask(bn_module, thre)
        remain_num += int(mask.sum())
        # Multiplying gamma by the mask zeroes the pruned channels in place.
        bn_module.weight.data.mul_(mask)

    # NOTE(author, translated): with a too-high prune ratio, evaluate() ends up
    # with an empty sample_metrics list.
    mAP = eval_model(model_copy)[2].mean()

    print(f'Number of channels has been reduced from {len(sorted_bn)} to {remain_num}')
    print(f'Prune ratio: {1-remain_num/len(sorted_bn):.3f}')
    print(f'mAP of the pruned model is {mAP:.4f}')

    return thre


percent = 0.85
threshold = prune_and_eval(model, sorted_bn, percent)


#%%
def obtain_filters_mask(model, thre, CBL_idx, prune_idx):
    """Build per-layer channel masks for every CBL layer.

    Layers in `prune_idx` get the mask from obtain_bn_mask at threshold `thre`.
    All other CBL layers keep every channel (all-ones mask).

    Args:
        model: the model whose BN gammas define the masks.
        thre: pruning threshold, as returned by prune_and_eval.
        CBL_idx: indices of conv+BN(+activation) layers, in order.
        prune_idx: subset of CBL_idx that may be pruned.

    Returns:
        (num_filters, filters_mask): the number of channels kept per CBL layer
        and the matching numpy masks, both aligned with CBL_idx.

    Raises:
        RuntimeError: if some layer in prune_idx would lose all of its channels.
    """
    pruned = 0
    total = 0
    num_filters = []
    filters_mask = []
    prunable = set(prune_idx)  # O(1) membership tests inside the loop

    for idx in CBL_idx:
        bn_module = _get_bn_module(model, idx)
        if idx in prunable:
            mask = obtain_bn_mask(bn_module, thre).cpu().numpy()
            remain = int(mask.sum())  # channels kept
            pruned += mask.shape[0] - remain  # channels removed
            if remain == 0:
                print("Channels would be all pruned!")
                raise RuntimeError(
                    f'All {mask.shape[0]} channels of layer {idx} would be pruned; '
                    f'use a lower prune ratio.'
                )
            print(f'layer index: {idx:>3d} \t total channel: {mask.shape[0]:>4d} \t '
                  f'remaining channel: {remain:>4d}')
        else:
            # Not prunable: keep every channel.
            mask = np.ones(bn_module.weight.data.shape)
            remain = mask.shape[0]

        # NOTE(author, translated): ds_conv contains two conv layers, but its
        # mask is computed from its BN layer (length 256 for layer 14) -- worth
        # double-checking.
        total += mask.shape[0]
        num_filters.append(remain)
        filters_mask.append(mask.copy())

    prune_ratio = pruned / total
    print(f'Prune channels: {pruned}\tPrune ratio: {prune_ratio:.3f}')

    return num_filters, filters_mask


num_filters, filters_mask = obtain_filters_mask(model, threshold, CBL_idx, prune_idx)

#%%
# Map each CBL layer index to its channel mask.
CBLidx2mask = dict(zip(CBL_idx, filters_mask))

# "Loose" pruned model: same layer widths, pruned channels zeroed.
pruned_model = prune_model_keep_size(model, prune_idx, CBL_idx, CBLidx2mask)

pruned_model_metric = eval_model(pruned_model)
print("mAP", f'{origin_model_metric[2].mean():.6f}', f'{pruned_model_metric[2].mean():.6f}')

#%%
# Copy the original layer definitions and shrink each CBL layer's filter count
# to the number of channels it keeps.
compact_module_defs = deepcopy(model.module_defs)
for idx, num in zip(CBL_idx, num_filters):
    assert compact_module_defs[idx]['type'] in ('convolutional', 'ds_conv')
    compact_module_defs[idx]['filters'] = str(num)

#%%
# compact_model is the physically smaller network. pruned_model only zeroed the
# pruned channels' weights and still has the original layer widths.
compact_model = Darknet([model.hyperparams.copy()] + compact_module_defs).to(device)

# Parameter count of the compact architecture; it does not depend on the
# weights copied in below.
compact_nparameters = obtain_num_parameters(compact_model)

# Copy the surviving channels' weights from the loose pruned model.
init_weights_from_loose_model(compact_model, pruned_model, CBL_idx, Conv_idx, CBLidx2mask)

#%%
random_input = torch.rand((1, 3, model.img_size, model.img_size)).to(device)


def obtain_avg_forward_time(input, model, repeat=200):
    """Measure the average forward-pass time of `model` on `input`.

    Puts `model` in eval mode and runs `repeat` forward passes under no_grad.
    On CUDA, kernels launch asynchronously. Without synchronizing before each
    clock read, the measurement would mostly capture launch overhead rather
    than the computation itself, so the device is synchronized around the loop.

    Returns:
        (average seconds per pass, output of the last pass).
    """
    model.eval()
    on_cuda = input.is_cuda
    output = None
    with torch.no_grad():
        if on_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(repeat):
            output = model(input)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return elapsed / repeat, output


# Inference time and output of the loose and the compact model.
pruned_forward_time, pruned_output = obtain_avg_forward_time(random_input, pruned_model)
compact_forward_time, compact_output = obtain_avg_forward_time(random_input, compact_model)

# Sanity check: both models should give (nearly) identical outputs. A large
# difference means the weight transfer went wrong.
diff = (pruned_output - compact_output).abs().gt(0.001).sum().item()
if diff > 0:
    print('Something wrong with the pruned model!')

#%%
# Evaluate the compact model on the validation set.
compact_model_metric = eval_model(compact_model)

#%%
# Compare mAP, parameter count and inference time before/after pruning.
# "Before" inference time is measured on the loose pruned model, which has the
# original layer widths.
metric_table = [
    ["Metric", "Before", "After"],
    ["mAP", f'{origin_model_metric[2].mean():.6f}', f'{compact_model_metric[2].mean():.6f}'],
    ["Parameters", f"{origin_nparameters}", f"{compact_nparameters}"],
    ["Inference", f'{pruned_forward_time:.4f}', f'{compact_forward_time:.4f}'],
]
print(AsciiTable(metric_table).table)


#%%
def _add_prefix_to_basename(path, prefix):
    """Return `path` with `prefix` prepended to its last '/'-separated component.

    Replaces the former `path.replace('/', '/' + prefix)`, which inserted the
    prefix after every '/' in deeper paths. It also returned a bare file name
    unchanged, which made torch.save overwrite the source checkpoint.
    """
    head, sep, tail = path.rpartition('/')
    return f'{head}{sep}{prefix}{tail}'


# Write the pruned cfg and save the compact model. The "ckpt_99_05181112" part
# is hard-coded from the checkpoint name and must be updated when it changes.
pruned_cfg_name = _add_prefix_to_basename(
    opt.model_def, f'ds_1w_prune_{percent}_ckpt_99_05181112')
pruned_cfg_file = write_cfg(pruned_cfg_name, [model.hyperparams.copy()] + compact_module_defs)
print(f'Config file has been saved: {pruned_cfg_file}')

compact_model_name = _add_prefix_to_basename(opt.model, f'ds_1w_prune_{percent}_')
torch.save(compact_model.state_dict(), compact_model_name)
print(f'Compact model has been saved: {compact_model_name}')