本文基于
prune_training.py
文件,展示如何使用 PyTorch 对 GPT-2 Student 模型进行 L1 不规则剪枝(Unstructured Pruning),分析剪枝策略、实现代码、效果影响及保存模型的关键细节,帮助你将训练好的模型进一步轻量化。
模型训练完成后,仍存在大量“权重占位但几乎不贡献预测”的参数,剪枝可以:
本项目采用 PyTorch 自带的 torch.nn.utils.prune
模块完成剪枝。
.
├── prune_training.py
├── ../python3_distillation/gpt2_student_v2/ # 蒸馏后小模型
├── ./gpt2_student_v2_pruned/ # 剪枝后输出目录
model_path = "../python3_distillation/gpt2_student_v2"
save_path = "./gpt2_student_v2_pruned"
prune_ratio = 0.3 # 剪掉 30% 权重
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained(model_path)
此时模型为完整蒸馏后结构(如 12 层 transformer + full embedding)。
from torch.nn.utils import prune
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name="weight", amount=prune_ratio)
prune.remove(module, "weight") # 彻底置零,永久保存
说明:
操作 | 含义 |
---|---|
l1_unstructured |
按 L1 范数排序剪掉绝对值最小的参数 |
name="weight" |
仅对 Linear 层中的权重参数剪枝 |
prune.remove() |
删除掩码 mask ,将剪枝“写死”进参数矩阵 |
⚠️ 不调用 remove()
则权重仍带有 mask,保存模型后会报错。
model.save_pretrained(save_path)
输出结构为标准 Hugging Face 格式(可被推理脚本和 API 直接加载):
gpt2_student_v2_pruned/
├── config.json
├── pytorch_model.bin
├── ...
# 1. 创建输出目录
os.makedirs(save_path, exist_ok=True)
# 2. 加载模型
model = GPT2LMHeadModel.from_pretrained(model_path)
# 3. 执行 L1 剪枝 + 清除掩码
for module in model.modules():
if isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name="weight", amount=0.3)
prune.remove(module, "weight")
# 4. 保存模型
model.save_pretrained(save_path)
项目 | 是否重要 | 建议说明 |
---|---|---|
是否必须 prune.remove() | 是 | 否则保存模型后加载将报错 |
是否影响结构 | 否 | 仍是 GPT2LMHeadModel 结构 |
是否影响 tokenizer | 否 | tokenizer 与模型剪枝无关 |
剪枝比例是否越高越好 | 否 | 推荐 20~50%,过高影响性能 |
torch.nn.utils.prune
,简单可靠YoanAILab 技术导航页
项目源码 × 实战部署 × 转型经验,一页总览
点击查看完整导航页
包含内容: