YOLOv3-model-pruning/utils/prune_utils.py

import torch
from terminaltables import AsciiTable
from copy import deepcopy
import numpy as np
import torch.nn.functional as F
def get_sr_flag(epoch, sr):
    # return epoch >= 5 and sr  # optional warm-up variant, kept for reference
    return sr
def parse_module_defs(module_defs):
CBL_idx = []
Conv_idx = []
for i, module_def in enumerate(module_defs):
        # 'ds_conv' is checked alongside 'convolutional' so ds_conv layers are
        # also collected here; to exclude them from pruning, add them to the
        # ignore set below instead.
if module_def['type'] == 'convolutional' or module_def['type'] == 'ds_conv':
if module_def['batch_normalize'] == '1':
CBL_idx.append(i)
                # CBL_idx: indices of the layers that have both a conv and a BN, i.e.
                # [0, 1, 2, 3, 5, 6, 7, 9, 10, 12, 13, 14, 16, 17, 19, 20, 22, 23, 25, 26, 28, 29, 31, 32, 34, 35, 37, 38,
                #  39, 41, 42, 44, 45, 47, 48, 50, 51, 53, 54, 56, 57, 59, 60, 62, 63, 64, 66, 67, 69, 70, 72, 73, 75, 76, 77,
                #  78, 79, 80, 84, 87, 88, 89, 90, 91, 92, 96, 99, 100, 101, 102, 103, 104]
else:
                # Conv_idx: layers with a conv but no BN, here [81, 93, 105]
Conv_idx.append(i)
    ignore_idx = set()  # layers that must not be pruned
for i, module_def in enumerate(module_defs):
        # The layer feeding a ds_conv must not be pruned: the depthwise conv's
        # group count has to match its input channel count.
        if module_def['type'] == 'ds_conv':
            ignore_idx.add(i - 1)
        if module_def['type'] == 'shortcut':
            ignore_idx.add(i - 1)  # both shortcut inputs must keep matching channels
identity_idx = (i + int(module_def['from']))
if module_defs[identity_idx]['type'] == 'convolutional':
ignore_idx.add(identity_idx)
elif module_defs[identity_idx]['type'] == 'shortcut':
ignore_idx.add(identity_idx - 1)
    # the conv layers just before the two upsample layers are kept intact
    ignore_idx.add(84)
    ignore_idx.add(96)
    # print(ignore_idx)
    # ignore_idx with the original structure:
    # {1, 3, 5, 7, 10, 12, 14, 17, 20, 23, 26, 29, 32, 35, 37, 39, 42, 45, 48, 51, 54, 57, 60, 62, 64, 67, 70, 73, 84, 96}
    # ignore_idx with the ds_conv structure:
    # {1, 3, 5, 7, 10, 12, 13, 14, 16, 17, 19, 20, 22, 23, 25, 26, 28, 29, 31, 32, 34, 35, 37, 38, 39, 41, 42, 44, 45, 47,
    #  48, 50, 51, 53, 54, 56, 57, 59, 60, 62, 63, 64, 66, 67, 69, 70, 72, 73, 84, 96}
prune_idx = [idx for idx in CBL_idx if idx not in ignore_idx]
    # prune_idx with the original structure:
    # [0, 2, 6, 9, 13, 16, 19, 22, 25, 28, 31, 34, 38, 41, 44, 47, 50, 53, 56, 59, 63, 66, 69, 72, 75,
    #  76, 77, 78, 79, 80, 87, 88, 89, 90, 91, 92, 99, 100, 101, 102, 103, 104]
    # prune_idx with the ds_conv structure:
    # [0, 2, 6, 9, 75, 76, 77, 78, 79, 80, 87, 88, 89, 90, 91, 92, 99, 100, 101, 102, 103, 104]
    # Return the indices of the CBL blocks, of the standalone conv layers,
    # and of the layers that will actually be pruned.
return CBL_idx, Conv_idx, prune_idx
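

# A minimal sketch of what parse_module_defs consumes and returns, using a
# hypothetical five-block module_defs list (the real list is produced by the
# repo's .cfg parser; only the keys used above are filled in). The hardcoded
# indices 84/96 are YOLOv3-specific and simply never match here.
def _demo_parse_module_defs():
    module_defs = [
        {'type': 'convolutional', 'batch_normalize': '1'},  # 0: CBL, prunable
        {'type': 'convolutional', 'batch_normalize': '1'},  # 1: CBL
        {'type': 'convolutional', 'batch_normalize': '1'},  # 2: CBL
        {'type': 'shortcut', 'from': '-2'},                 # 3: protects 1 and 2
        {'type': 'convolutional', 'batch_normalize': '0'},  # 4: plain conv head
    ]
    CBL_idx, Conv_idx, prune_idx = parse_module_defs(module_defs)
    assert (CBL_idx, Conv_idx, prune_idx) == ([0, 1, 2], [4], [0])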
def gather_bn_weights(module_list, prune_idx):
    size_list = [module_list[idx][1].weight.data.shape[0] for idx in prune_idx]  # filter count of each prunable layer
    # print('module_list[0][1].weight.data.shape[0] = ', module_list[0][1].weight.data.shape[0])
    # print("size_list:", size_list)
    # size_list with the ds_conv structure:
    # [32, 32, 64, 64, 512, 1024, 512, 1024, 512, 1024, 256, 512, 256, 512, 256, 512, 128, 256, 128, 256, 128, 256]
    # size_list with the original structure:
    # [32, 32, 64, 64, 128, 128, 128, 128, 128, 128, 128, 128, 256, 256, 256, 256, 256, 256, 256, 256, 512, 512, 512, 512, 512,
    #  1024, 512, 1024, 512, 1024, 256, 512, 256, 512, 256, 512, 128, 256, 128, 256, 128, 256]
bn_weights = torch.zeros(sum(size_list))
# print('bn_weights = ', bn_weights)
# bn_weights == tensor([0., 0., 0., ..., 0., 0., 0.])
index = 0
for idx, size in zip(prune_idx, size_list):
# print('idx = ' ,idx, ' | size = ',size)
bn_weights[index:(index + size)] = module_list[idx][1].weight.data.abs().clone()
index += size
# print('bn_weights = ' ,bn_weights, ' | index = ',index)
# print('module_list[0][1].weight.data.abs().clone()_len = ',len(module_list[0][1].weight.data.abs().clone()))
    # Collect the BN gamma weights (as |gamma|) of all prunable CBL blocks;
    # the pruning threshold is derived from these values.
return bn_weights
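

# A minimal usage sketch with a hypothetical two-block module_list; each block
# follows the Sequential(Conv, BN, LeakyReLU) layout this file assumes, so
# module_list[idx][1] is the BN layer.
def _demo_gather_bn_weights():
    import torch.nn as nn
    module_list = nn.ModuleList([
        nn.Sequential(nn.Conv2d(3, 8, 3, padding=1, bias=False),
                      nn.BatchNorm2d(8), nn.LeakyReLU(0.1)),
        nn.Sequential(nn.Conv2d(8, 16, 3, padding=1, bias=False),
                      nn.BatchNorm2d(16), nn.LeakyReLU(0.1)),
    ])
    bn_weights = gather_bn_weights(module_list, prune_idx=[0, 1])
    assert bn_weights.shape[0] == 8 + 16  # one |gamma| per prunable channel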
def write_cfg(cfg_file, module_defs):
with open(cfg_file, 'w') as f:
for module_def in module_defs:
f.write(f"[{module_def['type']}]\n")
for key, value in module_def.items():
if key != 'type':
f.write(f"{key}={value}\n")
f.write("\n")
return cfg_file
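

# Round-trip sketch: writing a hypothetical two-block definition list back out
# in .cfg syntax (every non-'type' key becomes a key=value line under the
# [type] header).
def _demo_write_cfg(path='demo_pruned.cfg'):
    defs = [
        {'type': 'net', 'width': '416', 'height': '416'},
        {'type': 'convolutional', 'batch_normalize': '1', 'filters': '16'},
    ]
    write_cfg(path, defs)
    # demo_pruned.cfg now contains:
    # [net]
    # width=416
    # height=416
    #
    # [convolutional]
    # batch_normalize=1
    # filters=16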
class BNOptimizer:
@staticmethod
def updateBN(sr_flag, module_list, s, prune_idx):
if sr_flag:
for idx in prune_idx:
                # Sequential(Conv, BN, LeakyReLU)
bn_module = module_list[idx][1]
                # Sparsity training: add s * sign(gamma), the subgradient of the
                # L1 penalty s * |gamma|, to the BN weight gradients.
                bn_module.weight.grad.data.add_(s * torch.sign(bn_module.weight.data))  # L1
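

# updateBN is the sparsity-training step of Network Slimming (Liu et al., 2017):
# adding s * sign(gamma) to the gradient is exactly the subgradient of an extra
# L1 penalty s * |gamma| on the BN scale factors. A small numeric check on a
# hypothetical standalone BN layer:
def _demo_update_bn(s=1e-3):
    import torch.nn as nn
    bn = nn.BatchNorm2d(4)
    bn(torch.randn(2, 4, 8, 8)).sum().backward()
    grad_before = bn.weight.grad.clone()
    bn.weight.grad.data.add_(s * torch.sign(bn.weight.data))  # the updateBN step
    assert torch.allclose(bn.weight.grad, grad_before + s * torch.sign(bn.weight.data))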
def obtain_quantiles(bn_weights, num_quantile=5):
sorted_bn_weights, i = torch.sort(bn_weights)
total = sorted_bn_weights.shape[0]
quantiles = sorted_bn_weights.tolist()[-1::-total//num_quantile][::-1]
print("\nBN weights quantile:")
quantile_table = [
[f'{i}/{num_quantile}' for i in range(1, num_quantile+1)],
["%.3f" % quantile for quantile in quantiles]
]
print(AsciiTable(quantile_table).table)
return quantiles
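

# The reversed-stride slicing above picks every (total // num_quantile)-th value
# from the top of the sorted weights and flips the result back to ascending
# order. A quick check on ten hypothetical weights:
def _demo_obtain_quantiles():
    quantiles = obtain_quantiles(torch.arange(10, dtype=torch.float), num_quantile=5)
    assert quantiles == [1.0, 3.0, 5.0, 7.0, 9.0]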
def get_input_mask(module_defs, idx, CBLidx2mask):
# print("CBLidx2mask ", CBLidx2mask)
if idx == 0:
return np.ones(3)
    if module_defs[idx - 1]['type'] in ('convolutional', 'ds_conv'):
        return CBLidx2mask[idx - 1]
elif module_defs[idx - 1]['type'] == 'shortcut':
return CBLidx2mask[idx - 2]
elif module_defs[idx - 1]['type'] == 'route':
route_in_idxs = []
for layer_i in module_defs[idx - 1]['layers'].split(","):
if int(layer_i) < 0:
route_in_idxs.append(idx - 1 + int(layer_i))
else:
route_in_idxs.append(int(layer_i))
if len(route_in_idxs) == 1:
return CBLidx2mask[route_in_idxs[0]]
elif len(route_in_idxs) == 2:
return np.concatenate([CBLidx2mask[in_idx - 1] for in_idx in route_in_idxs])
        else:
            raise ValueError("route module must have one or two input layers")
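

# A minimal sketch of the shortcut case: the layer after a shortcut sees the
# channels of the layer two steps back (the masks below are hypothetical).
def _demo_get_input_mask():
    module_defs = [
        {'type': 'convolutional'},           # 0
        {'type': 'shortcut', 'from': '-1'},  # 1
        {'type': 'convolutional'},           # 2: reads through the shortcut
    ]
    CBLidx2mask = {0: np.array([1., 0., 1., 1.])}
    mask = get_input_mask(module_defs, 2, CBLidx2mask)
    assert (mask == CBLidx2mask[0]).all()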
def init_weights_from_loose_model(compact_model, loose_model, CBL_idx, Conv_idx, CBLidx2mask):
# (14): Sequential(
# (ds_conv_d_14): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
# (ds_conv_p_14): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
# (batch_norm_14): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (leaky_14): LeakyReLU(negative_slope=0.1, inplace=True)
# )
for idx in CBL_idx:
compact_CBL = compact_model.module_list[idx]
loose_CBL = loose_model.module_list[idx]
# print("idx:",idx,"--compact_CBL :", compact_CBL, "--compact_CBL[1] :", compact_CBL[1],"--compact_CBL[2] :", compact_CBL[2])
# print("idx:",idx,"--loose_CBL :", loose_CBL, "--loose_CBL[1] :", loose_CBL[1],"--loose_CBL[2] :", loose_CBL[2])
        # np.argwhere returns the indices of the nonzero mask entries; [:, 0]
        # takes the first column, i.e. the kept output-channel indices.
out_channel_idx = np.argwhere(CBLidx2mask[idx])[:, 0].tolist()
# print("idx:",idx,"--out_channel_idx:", len(out_channel_idx))
# if (idx == 14):
# print("idx:",idx,"--out_channel_idx:", out_channel_idx, "--len:",len(out_channel_idx))
# print("CBLidx2mask[idx] :",CBLidx2mask[idx],"--len:",len(CBLidx2mask[idx]))
# print("loose_model.module_list[idx][1]:",loose_model.module_list[idx][1])
        # Locate the BN layer (position 2 in a ds_conv block, 1 otherwise) and
        # copy the parameters of the kept channels into the pruned model.
        if compact_model.module_defs[idx]['type'] == 'ds_conv':
            compact_bn, loose_bn = compact_CBL[2], loose_CBL[2]
        else:
            compact_bn, loose_bn = compact_CBL[1], loose_CBL[1]
# if idx==14:
# print("compact_bn.weight.data:",compact_bn.weight.data,"--len",len(compact_bn.weight.data))
# print("loose_bn.weight.data:",loose_bn.weight.data,"--len",len(loose_bn.weight.data))
# print("loose_bn.weight.data[out_channel_idx]:", loose_bn.weight.data[out_channel_idx],"--len",len(loose_bn.weight.data[out_channel_idx]))
compact_bn.weight.data = loose_bn.weight.data[out_channel_idx].clone()
compact_bn.bias.data = loose_bn.bias.data[out_channel_idx].clone()
compact_bn.running_mean.data = loose_bn.running_mean.data[out_channel_idx].clone()
compact_bn.running_var.data = loose_bn.running_var.data[out_channel_idx].clone()
        # The conv weights to copy depend on how the previous layer was pruned:
# print("idx:",idx)
input_mask = get_input_mask(loose_model.module_defs, idx, CBLidx2mask)
in_channel_idx = np.argwhere(input_mask)[:, 0].tolist()
# if (idx == 14):
# print("idx:",idx,"--in_channel_idx:", in_channel_idx, "--len:",len(in_channel_idx))
# print("idx:",idx,"--in_channel_idx:",len(in_channel_idx))
        if compact_model.module_defs[idx]['type'] == 'ds_conv':
            compact_conv, loose_conv = compact_CBL[0], loose_CBL[0]    # depthwise conv
            compact_conv2, loose_conv2 = compact_CBL[1], loose_CBL[1]  # pointwise conv
            # copy the weights into the pruned model
# print("loose_conv2.weight.data:",loose_conv2.weight.data,"--len:",len(loose_conv2.weight.data))
# print("loose_conv2.weight.data[:, in_channel_idx, :, :]:",loose_conv2.weight.data[:, in_channel_idx, :, :],"--len:",len(loose_conv2.weight.data[:, in_channel_idx, :, :]))
# print("compact_conv.weight.data:",compact_conv.weight.data,"--len:",len(compact_conv.weight.data))
# print("compact_conv2.weight.data:",compact_conv2.weight.data,"--len:",len(compact_conv2.weight.data))
# print("loose_conv.weight.data:",loose_conv.weight.data,"--len:",len(loose_conv.weight.data))
# print("loose_conv.weight.data[:, in_channel_idx, :, :]:",loose_conv.weight.data[:, in_channel_idx, :, :],"--len:",len(loose_conv.weight.data[:, in_channel_idx, :, :]))
            # ds_conv blocks are excluded from pruning, so both convs are copied unchanged
            # tmp1 = loose_conv.weight.data[:, in_channel_idx, :, :].clone()
            # compact_conv.weight.data = tmp1[in_channel_idx, :, :, :].clone()
            compact_conv.weight.data = loose_conv.weight.data.clone()
            compact_conv.bias.data = loose_conv.bias.data.clone()
            # tmp2 = loose_conv2.weight.data[:, in_channel_idx, :, :].clone()
            # compact_conv2.weight.data = tmp2[out_channel_idx, :, :, :].clone()
            compact_conv2.weight.data = loose_conv2.weight.data.clone()
            compact_conv2.bias.data = loose_conv2.bias.data.clone()
else:
compact_conv, loose_conv = compact_CBL[0], loose_CBL[0]
# print("idx:",idx,"--compact_CBL[0]:",compact_CBL[0],"--loose_CBL[0]",loose_CBL[0]) #还有一个可能就是ds_conv中有两个卷积层都需要复制权重
# print("idx:",idx,"--compact_CBL[1]:",compact_CBL[1],"--loose_CBL[0]",loose_CBL[1])
            # copy the selected input/output channels into the pruned model
tmp = loose_conv.weight.data[:, in_channel_idx, :, :].clone()
compact_conv.weight.data = tmp[out_channel_idx, :, :, :].clone()
for idx in Conv_idx:
compact_conv = compact_model.module_list[idx][0]
loose_conv = loose_model.module_list[idx][0]
        # This conv has no BN, but its input channels still depend on how the
        # previous layer was pruned.
input_mask = get_input_mask(loose_model.module_defs, idx, CBLidx2mask)
in_channel_idx = np.argwhere(input_mask)[:, 0].tolist()
        # copy the kept input channels; the output channels are untouched
compact_conv.weight.data = loose_conv.weight.data[:, in_channel_idx, :, :].clone()
compact_conv.bias.data = loose_conv.bias.data.clone()
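

# The copies above are plain fancy indexing: keep the weight rows (output
# channels) and columns (input channels) whose mask entries are 1. A standalone
# tensor sketch with hypothetical kept-channel indices:
def _demo_weight_slicing():
    loose_w = torch.randn(8, 4, 3, 3)  # (out_c, in_c, k, k)
    out_channel_idx, in_channel_idx = [0, 2, 5], [1, 3]
    tmp = loose_w[:, in_channel_idx, :, :]
    compact_w = tmp[out_channel_idx, :, :, :]
    assert compact_w.shape == (3, 2, 3, 3)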
def prune_model_keep_size(model, prune_idx, CBL_idx, CBLidx2mask):
    # work on a deep copy of the original model
pruned_model = deepcopy(model)
    # handle each prunable layer in turn
for idx in prune_idx:
        # channels to keep in this layer
        mask = torch.from_numpy(CBLidx2mask[idx]).cuda()
        # the BN layer's gamma parameters (its weights); pruned channels are zeroed below
bn_module = pruned_model.module_list[idx][1]
# print("bn_module:", pruned_model.module_list[idx][1])
bn_module.weight.data.mul_(mask)
        # Constant activation each *pruned* channel still produces: with gamma
        # zeroed it outputs leaky_relu(beta), computed per channel.
        activation = F.leaky_relu((1 - mask) * bn_module.bias.data, 0.1)
        # layers 79 and 91 also feed the convs just before the two upsample layers
next_idx_list = [idx + 1]
if idx == 79:
next_idx_list.append(84)
elif idx == 91:
next_idx_list.append(96)
# print("idx:",idx,"--next_idx_list:",next_idx_list)
        # propagate the change to each affected downstream layer
for next_idx in next_idx_list:
            # zeroing channels in this BN changes the input seen by the next conv
next_conv = pruned_model.module_list[next_idx][0]
            # sum over the spatial dims (h, w): channel pruning removes whole (h, w) feature maps
conv_sum = next_conv.weight.data.sum(dim=(2, 3))
            # Multiply the spatially-summed conv weights by the constant activations to get a
            # per-output-channel offset, then fold it into the next BN's running_mean or into the
            # next conv's bias. Standalone conv layers are never pruned themselves, so for them
            # only the bias is affected.
# print("idx:", {idx} , "| next_idx:" , {next_idx})
# print("conv_sum :", conv_sum)
# print("activation :", activation)
# print("--------------------------")
offset = conv_sum.matmul(activation.reshape(-1, 1)).reshape(-1)
if next_idx in CBL_idx:
next_bn = pruned_model.module_list[next_idx][1]
next_bn.running_mean.data.sub_(offset)
else:
next_conv.bias.data.add_(offset)
        # finally zero the betas of the pruned channels; their constant
        # contribution has been absorbed downstream
        bn_module.bias.data.mul_(mask)
    # return the pruned model (same architecture; pruned channels now output zero)
    return pruned_model
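

# A standalone numeric check of the bias-absorption trick above, on
# hypothetical shapes: a channel whose gamma is zeroed still emits the constant
# leaky_relu(beta); folding conv_sum @ activation into the next conv's bias
# lets that channel be silenced without changing the output. A 1x1 conv keeps
# the check exact even at the borders.
def _demo_bias_absorption():
    import torch.nn as nn
    next_conv = nn.Conv2d(3, 5, 1, bias=True)
    mask = torch.tensor([1., 0., 1.])                  # channel 1 is pruned
    beta = torch.tensor([0.3, -0.2, 0.7])
    activation = F.leaky_relu((1 - mask) * beta, 0.1)  # constant outputs of pruned channels
    x = torch.randn(1, 3, 8, 8) * mask.view(1, 3, 1, 1) + activation.view(1, 3, 1, 1)
    y_before = next_conv(x)
    conv_sum = next_conv.weight.data.sum(dim=(2, 3))
    offset = conv_sum.matmul(activation.reshape(-1, 1)).reshape(-1)
    next_conv.bias.data.add_(offset)                   # absorb the constant contribution
    y_after = next_conv(x * mask.view(1, 3, 1, 1))     # pruned channel silenced
    assert torch.allclose(y_before, y_after, atol=1e-5)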
def obtain_bn_mask(bn_module, thre):
thre = thre.cuda()
    # ge(a, b) is elementwise a >= b
mask = bn_module.weight.data.abs().ge(thre).float()
# print('thre = ',thre,"| mask = ", mask)
    # mask[i] == 1 keeps channel i, 0 prunes it
return mask
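

# A minimal sketch of how the threshold fed to obtain_bn_mask is typically
# derived: sort all gathered |gamma| values and index at the desired prune
# ratio (the ratio and weights below are hypothetical). The ge() test mirrors
# obtain_bn_mask without the .cuda() call, so it also runs on CPU.
def _demo_obtain_bn_mask(prune_ratio=0.5):
    gammas = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4])
    sorted_gammas, _ = torch.sort(gammas.abs())
    thre = sorted_gammas[int(len(sorted_gammas) * prune_ratio)]
    mask = gammas.abs().ge(thre).float()
    assert mask.sum().item() == 4  # half of the channels survive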