def get_layer_level_pruning_rate(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, val_loader, test_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size)
def evaluator(model):
return test(model, device, criterion, val_loader)
**get_data**:加载训练集、验证集、测试集及校准集(见前文分析)。
**evaluator 函数**:
full_model = get_trained_model(args.model, device)
full_sized_model_result = evaluate_model(full_model, device, test_loader, args.dataset.upper())
**get_trained_model**:加载预训练模型(如ResNet18)。
**evaluate_model**:测试完整模型在原始数据和触发样本上的性能(常规准确率+后门攻击成功率)。
for sparsity in [args.pruning_rate]:
print('\nCurrent desired sparsity:', sparsity, '\n')
**sparsity**:目标剪枝比例(如0.3表示剪除30%的权重)。
model = get_trained_model(args.model, device)
result = {'flops': {}, 'params': {}, 'performance':{}}
flops, params = count_flops_params(model, get_input_size(args.dataset))
result['flops']['original'] = flops
result['params']['original'] = params
**count_flops_params**:计算原始模型的浮点运算量(FLOPs)和参数量,结果形如 {'flops': {'original': 1.2e9}, 'params': {'original': 11.7e6}}。
evaluation_result = evaluator(model)
print('Evaluation result (original model): %s' % evaluation_result)
result['performance']['original'] = evaluation_result
**evaluator(model)**:测试未剪枝模型的验证集准确率(如95.2%),存入 result 字典中,用于后续对比。
if args.base_algo in ['l1', 'l2']:
op_types = ['Conv2d'] # 仅剪枝卷积层
elif args.base_algo == 'level':
op_types = ['default'] # 全局剪枝(包括全连接层)
config_list = [{
'sparsity': sparsity,
'op_types': op_types
}]
**base_algo 选择**:
- l1 / l2:按权重范数剪枝卷积层(通道剪枝)。
- level:通用权重剪枝(细粒度剪枝)。
**config_list**:指定剪枝目标(如对所有卷积层剪枝30%)。
pruner = SimulatedAnnealingPruner(
model,
config_list,
evaluator=evaluator,
base_algo=args.base_algo,
cool_down_rate=args.cool_down_rate,
experiment_data_dir=experiment_data_save_dir
)
- model:待剪枝模型。
- config_list:剪枝配置。
- evaluator:评估函数,用于指导剪枝策略搜索。
- cool_down_rate=0.9:模拟退火算法的冷却率,控制搜索空间收缩速度。
- experiment_data_dir:保存中间结果的路径。
model = pruner.compress()
**compress()**:核心剪枝操作,通过 model.weight * mask 实现权重冻结(实际移除权重需调用 ModelSpeedup)。
evaluation_result = evaluator(model)
print('Evaluation result (masked model): %s' % evaluation_result)
result['performance']['pruned'] = evaluation_result
attack_evaluation_result = evaluate_model(model, device, test_loader, args.dataset.upper())
**evaluation_result**:剪枝(掩码)后模型的验证集准确率。
**attack_evaluation_result**:剪枝后模型在测试集上的评估结果,检测剪枝是否影响后门行为。
flops, params = count_flops_params(model, get_input_size(args.dataset))
result['flops']['pruned'] = flops
result['params']['pruned'] = params
best_config_save_path = ckpt_file_name + str(sparsity)
with open(best_config_save_path, 'wb') as f:
pickle.dump(pruner._best_config_list, f)
**_best_config_list**:各层的最优剪枝率(如 {'conv1.weight': 0.2, 'conv2.weight': 0.4}),保存后可供 ModelSpeedup 使用(由 ModelSpeedup 移除无效权重)。
卷积层 FLOPs 约为 Kh × Kw × Cin × Cout × H × W(剪枝后 Cin 或 Cout 减少)。
graph TD
A[加载预训练模型] --> B[评估原始模型]
B --> C[初始化剪枝器]
C --> D[模拟退火搜索剪枝率]
D --> E[生成权重掩码]
E --> F[评估剪枝后模型]
F --> G[保存剪枝配置]
源码:
def get_layer_level_pruning_rate(args):
    """Search per-layer pruning rates with simulated annealing and record results.

    For the desired overall sparsity ``args.pruning_rate`` this function:
      1. evaluates the unpruned model (validation accuracy, FLOPs, params),
      2. runs ``SimulatedAnnealingPruner`` to search per-layer sparsities,
      3. evaluates the masked (pruned) model, and
      4. pickles the best per-layer config list to disk.

    Args:
        args: namespace providing at least ``dataset``, ``data_dir``,
            ``batch_size``, ``model``, ``pruning_rate``, ``base_algo``,
            ``cool_down_rate``, ``ckpt_dir_prefix`` and
            ``pretrained_model_dir``.

    Raises:
        ValueError: if ``args.base_algo`` is not one of 'l1', 'l2', 'level'.
    """
    # prepare dataset
    torch.manual_seed(0)  # fixed seed for reproducible data order / evaluation
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, test_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size)

    def evaluator(model):
        # Validation-set metric; used by the pruner as its search objective.
        return test(model, device, criterion, val_loader)

    full_model = get_trained_model(args.model, device)
    full_model.eval()
    # Baseline evaluation of the unpruned model on the test set.
    full_sized_model_result = evaluate_model(full_model, device, test_loader, args.dataset.upper())
    ckpt_file_name = args.ckpt_dir_prefix + args.pretrained_model_dir

    for sparsity in [args.pruning_rate]:
        print('\nCurrent desired sparsity:', sparsity, '\n')
        # NOTE(review): the original re-assigned device = 'cuda' here, which
        # crashes on CPU-only hosts; keep the device detected above instead.
        model = get_trained_model(args.model, device)
        result = {'flops': {}, 'params': {}, 'performance': {}}
        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['original'] = flops
        result['params']['original'] = params

        evaluation_result = evaluator(model)
        print('Evaluation result (original model): %s' % evaluation_result)
        result['performance']['original'] = evaluation_result

        # module types to prune, only "Conv2d" supported for channel pruning
        if args.base_algo in ['l1', 'l2']:
            print(args.base_algo)
            op_types = ['Conv2d']
        elif args.base_algo == 'level':
            op_types = ['default']
        else:
            # Original code left op_types unbound here (NameError later);
            # fail early with a clear message instead.
            raise ValueError('unsupported base_algo: %r' % args.base_algo)
        config_list = [{
            'sparsity': sparsity,
            'op_types': op_types
        }]

        experiment_data_save_dir = ckpt_file_name + '_desired_sparsity_' + str(sparsity)
        pruner = SimulatedAnnealingPruner(
            model, config_list, evaluator=evaluator, base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, experiment_data_dir=experiment_data_save_dir)
        model = pruner.compress()  # applies the searched weight masks

        evaluation_result = evaluator(model)
        print('Evaluation result (masked model): %s' % evaluation_result)
        result['performance']['pruned'] = evaluation_result
        # Test-set evaluation of the masked model.
        attack_evaluation_result = evaluate_model(model, device, test_loader, args.dataset.upper())

        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['pruned'] = flops
        result['params']['pruned'] = params

        best_config_save_path = ckpt_file_name + str(sparsity)
        print(best_config_save_path)
        # Persist the per-layer sparsities found by the search; the original
        # used open()/close() without a context manager and could leak the
        # file handle if pickle.dump raised.
        with open(best_config_save_path, 'wb') as f:
            pickle.dump(pruner._best_config_list, f)
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
非常有必要再学一下剪枝更具体的内容
目标是找到每一层神经网络的剪枝比例(例如卷积层的通道剪枝率),使得:
(如 sparsity=0.3 表示剪除30%的权重)。SimulatedAnnealingPruner 的工作流程如下:
生成初始剪枝率分布(如 {'conv1.weight': 0.2, 'conv2.weight': 0.4}),确保整体稀疏度满足目标(如30%)。
初始化较高温度(如 T=1.0),允许算法在早期接受较差的解(避免陷入局部最优)。
while T > T_min: # 温度降至阈值时停止
# 生成新解:对当前剪枝率分布做随机扰动
new_config = current_config + random_noise
# 计算能量差(即准确率变化)
delta_E = evaluator(model_with_new_config) - evaluator(model_with_current_config)
# 决定是否接受新解
if delta_E > 0 or random() < exp(delta_E / T):
current_config = new_config # 接受更优解或一定概率接受劣解
# 降温
T *= cool_down_rate # 例如 cool_down_rate=0.9
- cool_down_rate=0.9:控制温度下降速度(值越小收敛越快)。
- evaluator:验证集准确率,作为优化目标(即“能量”)。
在每次生成新解时,如何选择具体剪除哪些神经元由 base_algo 决定:
base_algo='l1'
:# 示例:对卷积层权重 [C_out, C_in, K, K],按输出通道的L1和排序
importance = weight.abs().sum(dim=(1, 2, 3)) # 形状 [C_out]
threshold = np.percentile(importance, sparsity * 100)
mask = importance > threshold # 保留重要通道
base_algo='l2':按权重的L2范数排序(流程与 l1 类似)。
base_algo='level':通用细粒度权重剪枝。
在 SimulatedAnnealingPruner.compress() 中:
搜索得到每层剪枝率(如 conv1: 20%, conv2: 40%),使得总稀疏度满足要求。
按 base_algo 对每层权重排序,生成二进制掩码标记保留/剪除的权重。# 伪代码:生成掩码
for layer in model.conv_layers:
importance = calculate_importance(layer.weight, base_algo)
threshold = find_threshold(importance, layer_sparsity)
layer.mask = importance >= threshold
通过 evaluator(model) 验证准确率,决定是否保留该剪枝方案。