import torch from terminaltables import AsciiTable from copy import deepcopy import numpy as np import torch.nn.functional as F def get_sr_flag(epoch, sr): # return epoch >= 5 and sr return sr def parse_module_defs(module_defs): CBL_idx = [] Conv_idx = [] for i, module_def in enumerate(module_defs): # 添加了 or module_def['type'] == 'ds_conv' 将ds_conv也纳入剪枝范围,如果不考虑的话就在下面的不需要剪枝那里添加 if module_def['type'] == 'convolutional' or module_def['type'] == 'ds_conv': if module_def['batch_normalize'] == '1': CBL_idx.append(i) #CBL_idx= 即为有卷积也有bn的层的idx #[0, 1, 2, 3, 5, 6, 7, 9, 10, 12, 13, 14, 16, 17, 19, 20, 22, 23, 25, 26, 28, 29, 31, 32, 34, 35, 37, 38, # 39, 41, 42, 44, 45, 47, 48, 50, 51, 53, 54, 56, 57, 59, 60, 62, 63, 64, 66, 67, 69, 70, 72, 73, 75, 76, 77, # 78, 79, 80, 84, 87, 88, 89, 90, 91, 92, 96, 99, 100, 101, 102, 103, 104] else: #Conv_idx = [81, 93, 105] 即为有卷积没有bn的层的idx Conv_idx.append(i) ignore_idx = set() #哪些层不需要剪枝 for i, module_def in enumerate(module_defs): #将ds_conv纳入不剪枝范围 if module_def['type'] == 'ds_conv': ignore_idx.add(i-1) if module_def['type'] == 'shortcut': ignore_idx.add(i-1) identity_idx = (i + int(module_def['from'])) if module_defs[identity_idx]['type'] == 'convolutional': ignore_idx.add(identity_idx) elif module_defs[identity_idx]['type'] == 'shortcut': ignore_idx.add(identity_idx - 1) ignore_idx.add(84) ignore_idx.add(96) # print(ignore_idx) #ignore_idx={1, 3, 5, 7, 10, 12, 14, 17, 20, 23, 26, 29, 32, 35, 37, 39, 42, 45, 48, 51, 54, 57, 60, 62, 64, 67, 70, 73, 84, 96} 改结构前 # ignore_idx={1, 3, 5, 7, 10, 12, 13, 14, 16, 17, 19, 20, 22, 23, 25, 26, 28, 29, 31, 32, 34, 35, 37, 38, 39, 41, 42, 44, 45, 47, # 48, 50, 51, 53, 54, 56, 57, 59, 60, 62, 63, 64, 66, 67, 69, 70, 72, 73, 84, 96} prune_idx = [idx for idx in CBL_idx if idx not in ignore_idx] #prune_idx =[0, 2, 6, 9, 13, 16, 19, 22, 25, 28, 31, 34, 38, 41, 44, 47, 50, 53, 56, 59, 63, 66, 69, 72, 75, # 76, 77, 78, 79, 80, 87, 88, 89, 90, 91, 92, 99, 100, 101, 102, 103, 104] 改结构前 # prune_idx = [0, 2, 6, 9, 75, 76, 77, 78, 79, 80, 87, 88, 89, 90, 91, 92, 99, 100, 101, 102, 103, 104] # 返回CBL组件的id,单独的Conv层的id,以及需要被剪枝的层的id return CBL_idx, Conv_idx, prune_idx def gather_bn_weights(module_list, prune_idx): size_list = [module_list[idx][1].weight.data.shape[0] for idx in prune_idx] # 存储prune_idx对应层的 filter数量 # print('module_list[0][1].weight.data.shape[0] = ',module_list[0][1].weight.data.shape[0]) # print("size_list:",size_list) # ds_conv对应的size_list [32, 32, 64, 64, 512, 1024, 512, 1024, 512, 1024, 256, 512, 256, 512, 256, 512, 128, 256, 128, 256, 128, 256] #[32, 32, 64, 64, 128, 128, 128, 128, 128, 128, 128, 128, 256, 256, 256, 256, 256, 256, 256, 256, 512, 512, 512, 512, 512, # 1024, 512, 1024, 512, 1024, 256, 512, 256, 512, 256, 512, 128, 256, 128, 256, 128, 256] bn_weights = torch.zeros(sum(size_list)) # print('bn_weights = ', bn_weights) # bn_weights == tensor([0., 0., 0., ..., 0., 0., 0.]) index = 0 for idx, size in zip(prune_idx, size_list): # print('idx = ' ,idx, ' | size = ',size) bn_weights[index:(index + size)] = module_list[idx][1].weight.data.abs().clone() index += size # print('bn_weights = ' ,bn_weights, ' | index = ',index) # print('module_list[0][1].weight.data.abs().clone()_len = ',len(module_list[0][1].weight.data.abs().clone())) # 获取CBL组件的BN层的权重,即Gamma参数,我们会根据这个参数来剪枝 return bn_weights def write_cfg(cfg_file, module_defs): with open(cfg_file, 'w') as f: for module_def in module_defs: f.write(f"[{module_def['type']}]\n") for key, value in module_def.items(): if key != 'type': f.write(f"{key}={value}\n") f.write("\n") return cfg_file class BNOptimizer(): @staticmethod def updateBN(sr_flag, module_list, s, prune_idx): if sr_flag: for idx in prune_idx: # Squential(Conv, BN, Lrelu) bn_module = module_list[idx][1] # 将批标准化模块(bn_module)的权重矩阵的梯度乘以一个缩放因子后添加到该权重矩阵的梯度上。 bn_module.weight.grad.data.add_(s * torch.sign(bn_module.weight.data)) # L1 def obtain_quantiles(bn_weights, num_quantile=5): sorted_bn_weights, i = torch.sort(bn_weights) total = sorted_bn_weights.shape[0] quantiles = sorted_bn_weights.tolist()[-1::-total//num_quantile][::-1] print("\nBN weights quantile:") quantile_table = [ [f'{i}/{num_quantile}' for i in range(1, num_quantile+1)], ["%.3f" % quantile for quantile in quantiles] ] print(AsciiTable(quantile_table).table) return quantiles def get_input_mask(module_defs, idx, CBLidx2mask): # print("CBLidx2mask : ", CBLidx2mask) if idx == 0: return np.ones(3) if module_defs[idx - 1]['type'] == 'convolutional': return CBLidx2mask[idx - 1] elif module_defs[idx - 1]['type'] == 'ds_conv': return CBLidx2mask[idx - 1] elif module_defs[idx - 1]['type'] == 'shortcut': return CBLidx2mask[idx - 2] elif module_defs[idx - 1]['type'] == 'route': route_in_idxs = [] for layer_i in module_defs[idx - 1]['layers'].split(","): if int(layer_i) < 0: route_in_idxs.append(idx - 1 + int(layer_i)) else: route_in_idxs.append(int(layer_i)) if len(route_in_idxs) == 1: return CBLidx2mask[route_in_idxs[0]] elif len(route_in_idxs) == 2: return np.concatenate([CBLidx2mask[in_idx - 1] for in_idx in route_in_idxs]) else: print("Something wrong with route module!") raise Exception def init_weights_from_loose_model(compact_model, loose_model, CBL_idx, Conv_idx, CBLidx2mask): # (14): Sequential( # (ds_conv_d_14): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128) # (ds_conv_p_14): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1)) # (batch_norm_14): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) # (leaky_14): LeakyReLU(negative_slope=0.1, inplace=True) # ) for idx in CBL_idx: compact_CBL = compact_model.module_list[idx] loose_CBL = loose_model.module_list[idx] # print("idx:",idx,"--compact_CBL :", compact_CBL, "--compact_CBL[1] :", compact_CBL[1],"--compact_CBL[2] :", compact_CBL[2]) # print("idx:",idx,"--loose_CBL :", loose_CBL, "--loose_CBL[1] :", loose_CBL[1],"--loose_CBL[2] :", loose_CBL[2]) # np.argwhere返回非0元素的索引,X[:,0]是numpy中数组的一种写法,表示对一个二维数组,取该二维数组第一维中的所有数据,第二维中取第0个数据 out_channel_idx = np.argwhere(CBLidx2mask[idx])[:, 0].tolist() # print("idx:",idx,"--out_channel_idx:", len(out_channel_idx)) # if (idx == 14): # print("idx:",idx,"--out_channel_idx:", out_channel_idx, "--len:",len(out_channel_idx)) # print("CBLidx2mask[idx] :",CBLidx2mask[idx],"--len:",len(CBLidx2mask[idx])) # print("loose_model.module_list[idx][1]:",loose_model.module_list[idx][1]) # 获取剪枝后的模型当前BN层的权重 if(compact_model.module_defs[idx]['type'] == 'ds_conv') : compact_bn, loose_bn = compact_CBL[2], loose_CBL[2] else: compact_bn, loose_bn = compact_CBL[1], loose_CBL[1] # if idx==14: # print("compact_bn.weight.data:",compact_bn.weight.data,"--len",len(compact_bn.weight.data)) # print("loose_bn.weight.data:",loose_bn.weight.data,"--len",len(loose_bn.weight.data)) # print("loose_bn.weight.data[out_channel_idx]:", loose_bn.weight.data[out_channel_idx],"--len",len(loose_bn.weight.data[out_channel_idx])) compact_bn.weight.data = loose_bn.weight.data[out_channel_idx].clone() compact_bn.bias.data = loose_bn.bias.data[out_channel_idx].clone() compact_bn.running_mean.data = loose_bn.running_mean.data[out_channel_idx].clone() compact_bn.running_var.data = loose_bn.running_var.data[out_channel_idx].clone() # 获取剪枝后的模型当前卷积层的权重,这和上一个卷积层的剪枝情况有关 # print("idx:",idx) input_mask = get_input_mask(loose_model.module_defs, idx, CBLidx2mask) in_channel_idx = np.argwhere(input_mask)[:, 0].tolist() # if (idx == 14): # print("idx:",idx,"--in_channel_idx:", in_channel_idx, "--len:",len(in_channel_idx)) # print("idx:",idx,"--in_channel_idx:",len(in_channel_idx)) if(compact_model.module_defs[idx]['type'] == 'ds_conv') : compact_conv, loose_conv = compact_CBL[0], loose_CBL[0] # compact_conv, loose_conv = compact_CBL[1], loose_CBL[1] compact_conv2, loose_conv2 = compact_CBL[1], loose_CBL[1] # 拷贝权重到剪枝后的模型中去 # print("loose_conv2.weight.data:",loose_conv2.weight.data,"--len:",len(loose_conv2.weight.data)) # print("loose_conv2.weight.data[:, in_channel_idx, :, :]:",loose_conv2.weight.data[:, in_channel_idx, :, :],"--len:",len(loose_conv2.weight.data[:, in_channel_idx, :, :])) # print("compact_conv.weight.data:",compact_conv.weight.data,"--len:",len(compact_conv.weight.data)) # print("compact_conv2.weight.data:",compact_conv2.weight.data,"--len:",len(compact_conv2.weight.data)) # print("loose_conv.weight.data:",loose_conv.weight.data,"--len:",len(loose_conv.weight.data)) # print("loose_conv.weight.data[:, in_channel_idx, :, :]:",loose_conv.weight.data[:, in_channel_idx, :, :],"--len:",len(loose_conv.weight.data[:, in_channel_idx, :, :])) # tmp1 = loose_conv.weight.data[:, in_channel_idx, :, :].clone() # compact_conv.weight.data = tmp1[in_channel_idx, :, :, :].clone() compact_conv.weight.data = loose_conv.weight.data.clone() compact_conv.bias.data = loose_conv.bias.data.clone() # tmp2 = loose_conv2.weight.data[:, in_channel_idx, :, :].clone() # compact_conv2.weight.data = tmp2[out_channel_idx, :, :, :].clone() compact_conv2.weight.data = loose_conv2.weight.data.clone() compact_conv2.bias.data = loose_conv2.bias.data.clone() else: compact_conv, loose_conv = compact_CBL[0], loose_CBL[0] # print("idx:",idx,"--compact_CBL[0]:",compact_CBL[0],"--loose_CBL[0]",loose_CBL[0]) #还有一个可能就是ds_conv中有两个卷积层都需要复制权重! # print("idx:",idx,"--compact_CBL[1]:",compact_CBL[1],"--loose_CBL[0]",loose_CBL[1]) # 拷贝权重到剪枝后的模型中去 tmp = loose_conv.weight.data[:, in_channel_idx, :, :].clone() compact_conv.weight.data = tmp[out_channel_idx, :, :, :].clone() for idx in Conv_idx: compact_conv = compact_model.module_list[idx][0] loose_conv = loose_model.module_list[idx][0] # 虽然当前层是不带BN的卷积层,但仍然和上一个层的剪枝情况是相关的 input_mask = get_input_mask(loose_model.module_defs, idx, CBLidx2mask) in_channel_idx = np.argwhere(input_mask)[:, 0].tolist() # 拷贝权重到剪枝后的模型中去 compact_conv.weight.data = loose_conv.weight.data[:, in_channel_idx, :, :].clone() compact_conv.bias.data = loose_conv.bias.data.clone() def prune_model_keep_size(model, prune_idx, CBL_idx, CBLidx2mask): # 先拷贝一份原始的模型参数 pruned_model = deepcopy(model) # 对需要剪枝的层分别处理 for idx in prune_idx: # 需要保留的通道 mask = torch.from_numpy(CBLidx2mask[idx]).cuda() # 获取BN层的gamma参数,即BN层的权重 bn_module = pruned_model.module_list[idx][1] # print("bn_module:", pruned_model.module_list[idx][1]) bn_module.weight.data.mul_(mask) # 获取保留下来的通道产生的激活值,注意是每个通道分别获取的 activation = F.leaky_relu((1 - mask) * bn_module.bias.data, 0.1) # 两个上采样层前的卷积层 next_idx_list = [idx + 1] if idx == 79: next_idx_list.append(84) elif idx == 91: next_idx_list.append(96) # print("idx:",idx,"--next_idx_list:",next_idx_list) # 对下一层进行处理 for next_idx in next_idx_list: # 当前层的BN剪枝之后会对下一个卷积层造成影响 next_conv = pruned_model.module_list[next_idx][0] # dim=(2,3)即在(w,h)维度上进行求和,因为是通道剪枝,一个通道对应着(w,h)这个矩形 conv_sum = next_conv.weight.data.sum(dim=(2, 3)) # 将卷积层的权重和激活值相乘获得剪枝后的每个通道的偏置,以更新下一个BN层或者下一个带偏置的卷积层的偏执(因为单独的卷积层是不会被剪枝的,所以只对偏置有影响 # print("idx:", {idx} , "| next_idx:" , {next_idx}) # print("conv_sum :", conv_sum) # print("activation :", activation) # print("--------------------------") offset = conv_sum.matmul(activation.reshape(-1, 1)).reshape(-1) if next_idx in CBL_idx: next_bn = pruned_model.module_list[next_idx][1] next_bn.running_mean.data.sub_(offset) else: next_conv.bias.data.add_(offset) bn_module.bias.data.mul_(mask) # 返回剪枝后的模型 return pruned_model def obtain_bn_mask(bn_module, thre): thre = thre.cuda() # ge(a, b)相当于 a>= b mask = bn_module.weight.data.abs().ge(thre).float() # print('thre = ',thre,"| mask = ", mask) # 返回通道是否需要剪枝的通道状态 return mask