# Channel-pruning utilities for YOLOv3-style detectors (network slimming).
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
from terminaltables import AsciiTable
def get_sr_flag(epoch, sr):
    """Return whether sparsity regularization is enabled for this epoch.

    The epoch is currently ignored and the flag is passed through
    unchanged.  (An earlier variant only enabled sparsity training from
    epoch 5 on: ``epoch >= 5 and sr``.)
    """
    return sr
def parse_module_defs(module_defs, extra_ignore_idx=(84, 96)):
    """Split module definitions into prunable and non-prunable layer indices.

    Parameters
    ----------
    module_defs : list[dict]
        Parsed darknet cfg blocks; each dict has at least a ``'type'`` key,
        and conv-like blocks carry a ``'batch_normalize'`` string flag.
    extra_ignore_idx : iterable[int], optional
        Extra layer indices that must never be pruned.  Defaults to
        ``(84, 96)`` — in the YOLOv3 layout these are the layers feeding the
        two upsample branches (previously hard-coded; now a parameter so the
        function works with other layouts).

    Returns
    -------
    tuple[list[int], list[int], list[int]]
        ``CBL_idx``   — indices of conv/ds_conv layers followed by a BN layer.
        ``Conv_idx``  — indices of conv-like layers without BN.
        ``prune_idx`` — subset of ``CBL_idx`` considered safe to prune.
    """
    CBL_idx = []
    Conv_idx = []
    for i, module_def in enumerate(module_defs):
        # 'ds_conv' (depthwise-separable conv) is treated like a plain conv
        # so that it is also considered for pruning.
        if module_def['type'] in ('convolutional', 'ds_conv'):
            if module_def['batch_normalize'] == '1':
                CBL_idx.append(i)
            else:
                Conv_idx.append(i)

    # Layers whose output channels must stay intact.
    ignore_idx = set()
    for i, module_def in enumerate(module_defs):
        # The layer feeding a ds_conv is kept whole (the depthwise stage
        # needs matching channel counts with its input).
        if module_def['type'] == 'ds_conv':
            ignore_idx.add(i - 1)

        if module_def['type'] == 'shortcut':
            # Both inputs of a residual add must keep identical channel sets,
            # so neither side of the shortcut may be pruned.
            ignore_idx.add(i - 1)
            identity_idx = i + int(module_def['from'])
            if module_defs[identity_idx]['type'] == 'convolutional':
                ignore_idx.add(identity_idx)
            elif module_defs[identity_idx]['type'] == 'shortcut':
                ignore_idx.add(identity_idx - 1)

    ignore_idx.update(extra_ignore_idx)

    prune_idx = [idx for idx in CBL_idx if idx not in ignore_idx]
    return CBL_idx, Conv_idx, prune_idx
def gather_bn_weights(module_list, prune_idx):
    """Collect |gamma| from every prunable BN layer into one flat tensor.

    ``module_list[idx][1]`` is assumed to be the BatchNorm module of the
    Conv-BN-LeakyReLU block at index ``idx``.  The absolute values of its
    scale (gamma) weights are concatenated in ``prune_idx`` order; the
    global pruning threshold is later derived from this vector.
    """
    gammas = [module_list[idx][1].weight.data.abs().clone() for idx in prune_idx]
    if not gammas:
        # No prunable layers: return an empty tensor, like the original
        # zero-filled buffer of total size 0.
        return torch.zeros(0)
    return torch.cat(gammas)
def write_cfg(cfg_file, module_defs):
    """Serialize module definitions back into darknet ``.cfg`` format.

    Each block is written as ``[type]`` followed by one ``key=value`` line
    per remaining entry and a trailing blank line.  Returns the path that
    was written, for convenient chaining.
    """
    with open(cfg_file, 'w') as f:
        for module_def in module_defs:
            f.write(f"[{module_def['type']}]\n")
            f.writelines(
                f"{key}={value}\n"
                for key, value in module_def.items()
                if key != 'type'
            )
            f.write("\n")
    return cfg_file
class BNOptimizer():
    """Applies the network-slimming L1 penalty to BN scale factors."""

    @staticmethod
    def updateBN(sr_flag, module_list, s, prune_idx):
        """Add the subgradient of ``s * |gamma|`` to each prunable BN layer.

        No-op when ``sr_flag`` is False.  ``module_list[idx][1]`` is the
        BatchNorm module of the Conv-BN-LeakyReLU block at ``idx``.
        """
        if not sr_flag:
            return
        for idx in prune_idx:
            bn_module = module_list[idx][1]
            # d/dw of s*|w| is s*sign(w); accumulate into the existing
            # gradient so sparsity training pushes gammas toward zero (L1).
            bn_module.weight.grad.data.add_(s * torch.sign(bn_module.weight.data))
def obtain_quantiles(bn_weights, num_quantile=5):
    """Print and return ``num_quantile`` evenly spaced quantiles of the BN
    weights (ascending).

    Useful for eyeballing how aggressive a given pruning threshold would be.
    """
    sorted_bn_weights, i = torch.sort(bn_weights)
    total = sorted_bn_weights.shape[0]
    # Walk backwards through the sorted list in steps of total//num_quantile,
    # then flip so the quantiles come out in ascending order.
    quantiles = sorted_bn_weights.tolist()[-1::-total//num_quantile][::-1]
    print("\nBN weights quantile:")
    header = [f'{i}/{num_quantile}' for i in range(1, num_quantile+1)]
    values = ["%.3f" % quantile for quantile in quantiles]
    print(AsciiTable([header, values]).table)

    return quantiles
def get_input_mask(module_defs, idx, CBLidx2mask):
    """Return the channel keep-mask describing the *input* of layer ``idx``.

    The input mask of a layer equals the output mask of whichever layer(s)
    feed it, so this dispatches on the type of the preceding module.
    """
    if idx == 0:
        # The very first layer sees the raw 3-channel image.
        return np.ones(3)

    prev = module_defs[idx - 1]
    prev_type = prev['type']
    if prev_type in ('convolutional', 'ds_conv'):
        return CBLidx2mask[idx - 1]
    if prev_type == 'shortcut':
        # A shortcut passes through the channels of the conv right before it.
        return CBLidx2mask[idx - 2]
    if prev_type == 'route':
        route_in_idxs = []
        for layer_i in prev['layers'].split(","):
            offset = int(layer_i)
            # Negative entries are relative to the route layer itself;
            # non-negative entries are absolute layer indices.
            route_in_idxs.append(idx - 1 + offset if offset < 0 else offset)
        if len(route_in_idxs) == 1:
            return CBLidx2mask[route_in_idxs[0]]
        if len(route_in_idxs) == 2:
            # NOTE(review): the two-input route looks up ``in_idx - 1`` —
            # presumably because in the YOLOv3 layout those entries point at
            # shortcut/upsample modules whose conv sits one index earlier;
            # confirm against the cfg before reusing with other layouts.
            return np.concatenate([CBLidx2mask[in_idx - 1] for in_idx in route_in_idxs])
        print("Something wrong with route module!")
        raise Exception
def init_weights_from_loose_model(compact_model, loose_model, CBL_idx, Conv_idx, CBLidx2mask):
    """Copy the surviving channels' weights from the unpruned ("loose")
    model into the freshly built pruned ("compact") model.

    For every Conv-BN block the BN statistics and conv kernels are sliced
    by the output keep-mask of that layer and the input keep-mask derived
    from its predecessors; BN-less conv layers only have their input
    channels sliced.  Mutates ``compact_model`` in place; returns None.

    A ds_conv block is laid out as
      Sequential(depthwise Conv2d, pointwise Conv2d, BatchNorm2d, LeakyReLU)
    while a plain block is Sequential(Conv2d, BatchNorm2d, LeakyReLU),
    which is why the BN sits at index 2 vs. index 1 below.
    """
    for idx in CBL_idx:
        compact_CBL = compact_model.module_list[idx]
        loose_CBL = loose_model.module_list[idx]
        # np.argwhere returns the indices of nonzero mask entries; [:, 0]
        # flattens the (n, 1) result into a plain index list of channels
        # that survive pruning.
        out_channel_idx = np.argwhere(CBLidx2mask[idx])[:, 0].tolist()

        # Locate the BN module: index 2 inside a ds_conv block, index 1 in
        # a plain Conv-BN-LeakyReLU block.
        if(compact_model.module_defs[idx]['type'] == 'ds_conv') :
            compact_bn, loose_bn = compact_CBL[2], loose_CBL[2]
        else:
            compact_bn, loose_bn = compact_CBL[1], loose_CBL[1]
        # Copy only the surviving channels' BN parameters and statistics.
        compact_bn.weight.data = loose_bn.weight.data[out_channel_idx].clone()
        compact_bn.bias.data = loose_bn.bias.data[out_channel_idx].clone()
        compact_bn.running_mean.data = loose_bn.running_mean.data[out_channel_idx].clone()
        compact_bn.running_var.data = loose_bn.running_var.data[out_channel_idx].clone()

        # The conv kernel's input channels depend on how the *previous*
        # layer was pruned.
        input_mask = get_input_mask(loose_model.module_defs, idx, CBLidx2mask)
        in_channel_idx = np.argwhere(input_mask)[:, 0].tolist()
        if(compact_model.module_defs[idx]['type'] == 'ds_conv') :
            # Depthwise stage at [0], pointwise stage at [1].
            compact_conv, loose_conv = compact_CBL[0], loose_CBL[0]
            compact_conv2, loose_conv2 = compact_CBL[1], loose_CBL[1]

            # NOTE(review): ds_conv weights are copied wholesale — no
            # in/out channel slicing.  This is only correct if ds_conv
            # layers and their inputs are never pruned (parse_module_defs
            # does add them to ignore_idx); confirm before changing that.
            compact_conv.weight.data = loose_conv.weight.data.clone()
            compact_conv.bias.data = loose_conv.bias.data.clone()
            compact_conv2.weight.data = loose_conv2.weight.data.clone()
            compact_conv2.bias.data = loose_conv2.bias.data.clone()
        else:
            compact_conv, loose_conv = compact_CBL[0], loose_CBL[0]
            # Slice input channels first, then output channels
            # (weight layout is [out, in, kh, kw]).
            tmp = loose_conv.weight.data[:, in_channel_idx, :, :].clone()
            compact_conv.weight.data = tmp[out_channel_idx, :, :, :].clone()

    for idx in Conv_idx:
        # Conv layers without BN keep all their output channels, but their
        # input channels still depend on the previous layer's pruning.
        compact_conv = compact_model.module_list[idx][0]
        loose_conv = loose_model.module_list[idx][0]

        input_mask = get_input_mask(loose_model.module_defs, idx, CBLidx2mask)
        in_channel_idx = np.argwhere(input_mask)[:, 0].tolist()

        compact_conv.weight.data = loose_conv.weight.data[:, in_channel_idx, :, :].clone()
        compact_conv.bias.data = loose_conv.bias.data.clone()
def prune_model_keep_size(model, prune_idx, CBL_idx, CBLidx2mask):
    """Mask pruned channels on a copy of ``model`` while keeping every
    tensor's original size.

    For each pruned BN channel, the constant activation it would still emit
    (LeakyReLU of its bias) is folded into the next layer's statistics, so
    the masked model stays numerically equivalent to a truly pruned one.

    Parameters
    ----------
    model : network to prune; not modified — a deepcopy is returned.
    prune_idx : indices of Conv-BN blocks whose channels get masked.
    CBL_idx : indices of all Conv-BN blocks; used to decide whether the
        following layer has a BN that can absorb the constant offset.
    CBLidx2mask : dict mapping layer index -> float 0/1 keep-mask (numpy).

    Returns the masked copy of the model.
    """
    pruned_model = deepcopy(model)
    for idx in prune_idx:
        bn_module = pruned_model.module_list[idx][1]
        # Move the keep-mask to wherever the model's tensors live.
        # (Previously this called .cuda(), which broke CPU-only runs.)
        mask = torch.from_numpy(CBLidx2mask[idx]).to(bn_module.weight.device)

        # Zero the pruned channels' BN scale (gamma).
        bn_module.weight.data.mul_(mask)
        # Constant activation each *pruned* channel would still produce:
        # LeakyReLU(beta) with the block's 0.1 negative slope.
        activation = F.leaky_relu((1 - mask) * bn_module.bias.data, 0.1)

        # The layer right after idx consumes these activations.  In the
        # YOLOv3 head, layers 84/96 additionally consume the outputs of
        # 79/91 through the two upsample-route branches.
        next_idx_list = [idx + 1]
        if idx == 79:
            next_idx_list.append(84)
        elif idx == 91:
            next_idx_list.append(96)

        for next_idx in next_idx_list:
            next_conv = pruned_model.module_list[next_idx][0]
            # Sum over the kernel's spatial dims (w, h): total weight mass
            # per (out_channel, in_channel) pair.
            conv_sum = next_conv.weight.data.sum(dim=(2, 3))
            # Constant contribution of the pruned channels to each output
            # channel of the next conv.
            offset = conv_sum.matmul(activation.reshape(-1, 1)).reshape(-1)
            if next_idx in CBL_idx:
                # Fold the constant into the following BN's running mean.
                next_bn = pruned_model.module_list[next_idx][1]
                next_bn.running_mean.data.sub_(offset)
            else:
                # No BN: fold it into the conv bias instead (bias-only conv
                # layers are never pruned themselves).
                next_conv.bias.data.add_(offset)

        # Finally zero the pruned channels' BN shift (beta).
        bn_module.bias.data.mul_(mask)
    return pruned_model
def obtain_bn_mask(bn_module, thre):
    """Return a float 0/1 keep-mask over ``bn_module``'s channels.

    A channel is kept (1.0) when ``|gamma| >= thre`` and pruned (0.0)
    otherwise.  ``thre`` is moved to the BN weight's device, so this works
    on CPU and GPU alike (previously it unconditionally called .cuda()).
    """
    thre = thre.to(bn_module.weight.device)
    # ge(a, b) is elementwise a >= b.
    mask = bn_module.weight.data.abs().ge(thre).float()
    return mask