模型剪枝(分析)

1. 函数入口与设备初始化

def get_layer_level_pruning_rate(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  • 功能:检测可用设备(优先使用GPU),为后续模型加载做准备。

2. 数据加载与评估函数定义

    train_loader, val_loader, test_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size)

    def evaluator(model):
        return test(model, device, criterion, val_loader)
  • ​**get_data**:加载训练集、验证集、测试集及校准集(见前文分析)。
  • ​**evaluator函数**:
    • 输入:待评估的模型
    • 输出:模型在验证集上的准确率
    • 作用:剪枝过程中用于评估模型性能,指导剪枝策略优化。

3. 完整模型性能评估

    full_model = get_trained_model(args.model, device)
    full_sized_model_result = evaluate_model(full_model, device, test_loader, args.dataset.upper())
  • ​**get_trained_model**:加载预训练模型(如ResNet18)。
  • ​**evaluate_model**:测试完整模型在原始数据和触发样本上的性能(常规准确率+后门攻击成功率)。

4. 剪枝循环(按指定稀疏度)​

    for sparsity in [args.pruning_rate]:
        print('\nCurrent desired sparsity:', sparsity, '\n')
  • ​**sparsity**:目标剪枝比例(如0.3表示剪除30%的权重)。
  • 循环设计:支持多组稀疏度实验,此处仅测试用户指定的单一值。

5. 模型重新加载与初始化统计

        model = get_trained_model(args.model, device)
        result = {'flops': {}, 'params': {}, 'performance':{}}
        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['original'] = flops
        result['params']['original'] = params
  • 重新加载模型:确保每次剪枝实验从相同的初始状态开始。
  • 统计基准指标
    • count_flops_params:计算原始模型的浮点运算量(FLOPs)和参数量。
    • 示例输出:{'flops': {'original': 1.2e9}, 'params': {'original': 11.7e6}}

6. 原始模型性能评估

        evaluation_result = evaluator(model)
        print('Evaluation result (original model): %s' % evaluation_result)
        result['performance']['original'] = evaluation_result
  • ​**evaluator(model)**:测试未剪枝模型的验证集准确率(如95.2%)。
  • 记录结果:保存到result字典中,用于后续对比。

7. 剪枝配置设置

        if args.base_algo in ['l1', 'l2']:
            op_types = ['Conv2d']  # 仅剪枝卷积层
        elif args.base_algo == 'level':
            op_types = ['default']  # 全局剪枝(包括全连接层)

        config_list = [{
            'sparsity': sparsity,
            'op_types': op_types
        }]
  • ​**base_algo选择**:
    • l1/l2:按权重范数剪枝卷积层(通道剪枝)。
    • level:通用权重剪枝(细粒度剪枝)。
  • ​**config_list**:指定剪枝目标(如对所有卷积层剪枝30%)。

8. 剪枝器初始化

        pruner = SimulatedAnnealingPruner(
            model, 
            config_list, 
            evaluator=evaluator,
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            experiment_data_dir=experiment_data_save_dir
        )
  • 参数解析
    • model:待剪枝模型。
    • config_list:剪枝配置。
    • evaluator:评估函数,用于指导剪枝策略搜索。
    • cool_down_rate=0.9:模拟退火算法的冷却率,控制搜索空间收缩速度。
    • experiment_data_dir:保存中间结果的路径。

9. 执行剪枝

        model = pruner.compress()
  • ​**compress()**:核心剪枝操作,流程如下:
    1. 搜索最优剪枝率分布
      • 使用模拟退火算法,尝试不同层的剪枝组合,保留验证集准确率最高的方案。
    2. 生成掩码(Mask)​
      • 对每个权重张量生成二进制掩码(0表示剪枝,1表示保留)。
    3. 应用掩码
      • 通过model.weight * mask实现权重冻结(实际移除需调用ModelSpeedup)。

10. 剪枝后模型评估

        evaluation_result = evaluator(model)
        print('Evaluation result (masked model): %s' % evaluation_result)
        result['performance']['pruned'] = evaluation_result

        attack_evaluation_result = evaluate_model(model, device, test_loader, args.dataset.upper())
  • 性能对比
    • 常规任务准确率(evaluation_result)。
    • 后门攻击成功率(attack_evaluation_result),检测剪枝是否影响后门行为。

11. 计算剪枝后FLOPs与参数量

        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['pruned'] = flops
        result['params']['pruned'] = params
  • 统计压缩效果
    • 示例:剪枝后FLOPs从1.2e9降至0.8e9(减少33%)。

12. 保存剪枝配置

        best_config_save_path = ckpt_file_name + str(sparsity)
        with open(best_config_save_path, 'wb') as f:
            pickle.dump(pruner._best_config_list, f)
  • ​**_best_config_list**:各层的最优剪枝率(如{'conv1.weight': 0.2, 'conv2.weight': 0.4})。
  • 保存格式:二进制pickle文件,供后续模型加速(ModelSpeedup)使用。

关键技术与原理补充

  1. 模拟退火剪枝
    • 目标:找到各层剪枝率的最优组合,最大化验证集准确率。
    • 流程
      • 随机生成剪枝率分布 → 评估性能 → 根据“温度”参数接受或拒绝新解 → 降温并重复。
  2. 掩码(Mask)机制
    • 剪枝后权重未被物理删除,而是通过掩码屏蔽(需后续ModelSpeedup移除无效权重)。
  3. FLOPs计算
    • 公式:对于卷积层,FLOPs = Kh × Kw × Cin × Cout × H × W(剪枝后CinCout减少)。

完整流程图示

graph TD
    A[加载预训练模型] --> B[评估原始模型]
    B --> C[初始化剪枝器]
    C --> D[模拟退火搜索剪枝率]
    D --> E[生成权重掩码]
    E --> F[评估剪枝后模型]
    F --> G[保存剪枝配置]

源码:

def get_layer_level_pruning_rate(args):
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader, test_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    full_model = get_trained_model(args.model, device)
    full_model.eval()
    full_sized_model_result = evaluate_model(full_model, device, test_loader, args.dataset.upper())


    ckpt_file_name = args.ckpt_dir_prefix + args.pretrained_model_dir

    for sparsity in [args.pruning_rate]:
        print('\nCurrent desired sparsity:', sparsity, '\n')
        device = 'cuda'
        model = get_trained_model(args.model, device)
        result = {'flops': {}, 'params': {}, 'performance':{}}
        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['original'] = flops
        result['params']['original'] = params

        evaluation_result = evaluator(model)
        print('Evaluation result (original model): %s' % evaluation_result)
        result['performance']['original'] = evaluation_result

        # module types to prune, only "Conv2d" supported for channel pruning
        if args.base_algo in ['l1', 'l2']:
            print(args.base_algo)
            op_types = ['Conv2d']
        elif args.base_algo == 'level':
            op_types = ['default']

        config_list = [{
            'sparsity': sparsity,
            'op_types': op_types
        }]

        experiment_data_save_dir = ckpt_file_name + '_desired_sparsity_' + str(sparsity)

        pruner = SimulatedAnnealingPruner(
            model, config_list, evaluator=evaluator, base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, experiment_data_dir=experiment_data_save_dir)

        model = pruner.compress()
        evaluation_result = evaluator(model)
        print('Evaluation result (masked model): %s' % evaluation_result)
        result['performance']['pruned'] = evaluation_result

        attack_evaluation_result = evaluate_model(model, device, test_loader, args.dataset.upper())

        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['pruned'] = flops
        result['params']['pruned'] = params

        best_config_save_path = ckpt_file_name + str(sparsity)
        print(best_config_save_path)
        f = open(best_config_save_path, 'wb')
        pickle.dump(pruner._best_config_list, f)
        f.close()

 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

非常有必要再学一下剪枝更具体的内容

1. 剪枝目标

目标是找到每一层神经网络的剪枝比例(例如卷积层的通道剪枝率),使得:

  • 模型整体达到指定的稀疏度(如 sparsity=0.3 表示剪除30%的权重)。
  • 剪枝后的模型在验证集上的准确率尽可能高。

2. 模拟退火算法流程

SimulatedAnnealingPruner 的工作流程如下:

​**(1) 初始化**
  • 随机生成初始剪枝方案
    为每一层随机分配一个剪枝率(如 {'conv1.weight': 0.2, 'conv2.weight': 0.4}),确保整体稀疏度满足目标(如30%)。
  • 初始温度(T)​
    设置高温(如 T=1.0),允许算法在早期接受较差的解(避免陷入局部最优)。
​**(2) 迭代搜索**
while T > T_min:  # 温度降至阈值时停止
    # 生成新解:对当前剪枝率分布做随机扰动
    new_config = current_config + random_noise
    
    # 计算能量差(即准确率变化)
    delta_E = evaluator(model_with_new_config) - evaluator(model_with_current_config)
    
    # 决定是否接受新解
    if delta_E > 0 or random() < exp(delta_E / T):
        current_config = new_config  # 接受更优解或一定概率接受劣解
    
    # 降温
    T *= cool_down_rate  # 例如 cool_down_rate=0.9
  • 关键参数
    • cool_down_rate=0.9:控制温度下降速度(值越小收敛越快)。
    • evaluator:验证集准确率,作为优化目标(即“能量”)。
​**(3) 剪枝标准(Base Algo)​**

在每次生成新解时,​如何选择具体剪除哪些神经元由 base_algo 决定:

  • ​**base_algo='l1'
    按权重的 ​
    L1范数(绝对值之和)​** 排序,剪除数值最小的通道(卷积层)或权重。
    # 示例:对卷积层权重 [C_out, C_in, K, K],按输出通道的L1和排序
    importance = weight.abs().sum(dim=(1, 2, 3))  # 形状 [C_out]
    threshold = np.percentile(importance, sparsity * 100)
    mask = importance > threshold  # 保留重要通道
  • ​**base_algo='l2'
    类似L1,但使用 ​
    L2范数(平方和开根号)​** 作为重要性指标。
  • ​**base_algo='level'**:
    全局剪枝,直接按权重的绝对值排序,剪除最小的权重(不区分层)。

3. 代码中的具体实现

在 SimulatedAnnealingPruner.compress() 中:

  1. 搜索剪枝率分布
    • 通过模拟退火找到各层的最佳剪枝比例(如 conv1: 20%conv2: 40%),使得总稀疏度满足要求。
  2. 生成掩码(Mask)​
    • 根据 base_algo 对每层权重排序,生成二进制掩码标记保留/剪除的权重。
    # 伪代码:生成掩码
    for layer in model.conv_layers:
        importance = calculate_importance(layer.weight, base_algo)
        threshold = find_threshold(importance, layer_sparsity)
        layer.mask = importance >= threshold
  3. 评估与选择
    • 每次生成新掩码后,调用 evaluator(model) 验证准确率,决定是否保留该剪枝方案。

4. 为什么能保证“较优”?

  1. 模拟退火的全局搜索能力
    • 高温阶段允许接受暂时性性能下降,避免陷入局部最优。
  2. 基于重要性的剪枝准则
    • L1/L2范数假设:数值小的权重对模型贡献小,剪除后影响较低。
  3. 直接优化验证集准确率
    • 剪枝策略的优劣最终由验证集性能决定,而非启发式规则。

你可能感兴趣的:(剪枝,python,机器学习)