在智能终端设备上部署深度学习模型面临三大核心痛点:
借助 PyTorch 2.2 提供的高级量化工具链,可以实现以下关键指标的提升:
指标 | 原始FP32模型 | INT8量化模型 | 提升幅度 |
---|---|---|---|
模型大小 | 400MB | 100MB | 75% |
推理延迟 | 320ms | 80ms | 75% |
能耗 | 2.5W | 0.8W | 68% |
准确率 | 85.2% | 84.7% | -0.5% |
$$Q(r) = \mathrm{round}\left(\frac{r}{S} + Z\right)$$
$$r' = S \times \left(Q(r) - Z\right)$$
其中:$r$ 为原始浮点值,$S$ 为缩放因子(scale),$Z$ 为零点(zero point),$Q(r)$ 为量化后的整数值。
位宽 | 动态范围 | 典型应用场景 |
---|---|---|
INT8 | [-128, 127] | 移动端、边缘设备 |
INT4 | [-8, 7] | 超轻量级模型 |
FP16 | ±65504 | GPU加速推理 |
BF16 | ±3.389e+38 | 深度学习训练 |
pip install torch torchvision torchaudio
pip install onnx onnxruntime
import torch
import torch.nn as nn
import torch.quantization as quantization
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Load a ResNet18 pre-trained on ImageNet.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision;
# `weights=models.ResNet18_Weights.DEFAULT` is the modern equivalent.
model = models.resnet18(pretrained=True)
model.eval()  # calibration must run in eval mode (freezes BatchNorm statistics)

# Standard ImageNet preprocessing pipeline.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Calibration dataset: a small, representative sample of the deployment data.
calibration_dataset = datasets.ImageFolder('path/to/calibration/data', transform=transform)
calibration_loader = DataLoader(calibration_dataset, batch_size=32, shuffle=False)

# Quantization backend. Bug fix: the original used 'fbgemm' with a comment
# claiming it targets ARM CPUs -- 'fbgemm' is the x86 server backend.
# 'qnnpack' is the backend that matches the ARM mobile/edge deployment
# this guide targets.
model.qconfig = quantization.get_default_qconfig('qnnpack')
# Let the converted model execute with the qnnpack kernels on this host too.
torch.backends.quantized.engine = 'qnnpack'

# Insert observers that will record activation ranges during calibration.
model_prepared = quantization.prepare(model)

# Run the calibration data through the model to collect activation statistics.
with torch.no_grad():
    for inputs, _ in calibration_loader:
        model_prepared(inputs)

# Replace observed modules with their quantized (INT8) equivalents.
model_quantized = quantization.convert(model_prepared)

# Serialize as TorchScript for mobile deployment.
torch.jit.save(torch.jit.script(model_quantized), "resnet18_quantized.pt")
# Dynamic quantization: weights are converted to INT8 ahead of time while
# activations are quantized on the fly per batch (best suited to
# Linear/RNN-heavy models).
# Bug fix: the original also listed nn.Conv2d in the module set, but
# dynamic quantization does not implement Conv2d -- PyTorch silently
# ignores the entry, so it was misleading dead configuration.
model_dynamic = quantization.quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8
)
# A finer-grained quantization configuration:
# - activations: per-tensor affine quantization to unsigned INT8
# - weights: per-channel symmetric quantization to signed INT8
#   (per-channel weight scales typically recover accuracy that a single
#   per-tensor scale loses)
from torch.ao.quantization import QConfig, MinMaxObserver, PerChannelMinMaxObserver
qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8, qscheme=torch.per_tensor_affine),
    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
)
model.qconfig = qconfig
class FakeQuantize(torch.nn.Module):
    """Simulate quantize->dequantize in FP32 so training can adapt to quantization error.

    In training mode the wrapped observer records the input's value range and
    the (scale, zero_point) buffers are refreshed from it; the fake-quantize
    op then rounds/clamps using those parameters while keeping the tensor in
    floating point.
    """

    def __init__(self, observer):
        super().__init__()
        # `observer` is presumably an already-constructed observer module
        # exposing __call__ and calculate_qparams() -- TODO confirm at call site.
        self.observer = observer
        # Registered as buffers so the current qparams travel with state_dict.
        self.register_buffer('scale', torch.tensor([1.0]))
        self.register_buffer('zero_point', torch.tensor([0]))

    def forward(self, x):
        if self.training:
            # Update range statistics and recompute qparams on every step.
            self.observer(x)
            self.scale, self.zero_point = self.observer.calculate_qparams()
        # quant range 0..255 corresponds to unsigned 8-bit quantization.
        return torch.fake_quantize_per_tensor_affine(
            x, self.scale, self.zero_point, 0, 255
        )
# Load the FP32 pre-trained model to fine-tune with quantization-aware training.
model = models.resnet18(pretrained=True)
# QAT configuration: inserts fake-quantize modules for weights and activations.
# NOTE(review): 'fbgemm' targets x86 CPUs; for the ARM mobile deployment
# described in this article, 'qnnpack' would be the matching backend -- confirm.
model.qconfig = quantization.get_default_qat_qconfig('fbgemm')
# Rewrite the model with fake-quant/observer modules (do this before training).
model_qat = quantization.prepare_qat(model)
# Optimizer and loss for the fine-tuning loop.
optimizer = torch.optim.SGD(model_qat.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# QAT fine-tuning loop: a few epochs are usually enough for the weights to
# adapt to the simulated quantization noise.
# NOTE(review): `train_loader` / `val_loader` are not defined in this
# snippet -- they are assumed to be standard classification DataLoaders.
for epoch in range(5):
    model_qat.train()
    for inputs, labels in train_loader:
        outputs = model_qat(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()  # clear stale gradients before backward
        loss.backward()
        optimizer.step()
    # Validation pass after each epoch.
    model_qat.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, labels in val_loader:
            outputs = model_qat(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Epoch {epoch+1}, Accuracy: {100 * correct / total}%')

# Switch to eval mode (freezes BN statistics and observers), then convert
# the fake-quantized modules into real INT8 kernels.
model_qat.eval()
model_quantized = quantization.convert(model_qat)
# Export as TorchScript: the standard format for PyTorch Mobile (Android/iOS).
scripted_model = torch.jit.script(model_quantized)
scripted_model.save("resnet18_quantized_scripted.pt")

# Export as ONNX with a dynamic batch dimension.
# NOTE(review): ONNX export of converted INT8 PyTorch models has limited
# operator coverage -- verify the exported graph actually loads in
# onnxruntime for this model.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model_quantized,
    dummy_input,
    "resnet18_quantized.onnx",
    export_params=True,        # store the trained weights in the ONNX file
    opset_version=13,
    do_constant_folding=True,  # pre-compute constant subgraphs
    input_names=['input'],
    output_names=['output'],
    # Allow the batch dimension to vary at inference time.
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
// build.gradle -- PyTorch Android runtime dependencies.
// pytorch_android bundles the libtorch mobile runtime;
// pytorch_android_torchvision provides image <-> tensor helpers.
dependencies {
    implementation 'org.pytorch:pytorch_android:2.2.0'
    implementation 'org.pytorch:pytorch_android_torchvision:2.2.0'
}
// Load the TorchScript module shipped in the app's assets.
Module module = Module.load(assetFilePath(context, "resnet18_quantized_scripted.pt"));
// Decode and preprocess the image (torchvision ImageNet mean/std normalization).
Bitmap bitmap = BitmapFactory.decodeFile(imagePath);
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
        bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
        TensorImageUtils.TORCHVISION_NORM_STD_RGB
);
// Run inference.
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// Post-process: raw class scores (apply softmax/argmax downstream as needed).
float[] scores = outputTensor.getDataAsFloatArray();
// Package.swift -- depend on the PyTorch release/2.2 branch.
dependencies: [
    .package(url: "https://github.com/pytorch/pytorch.git", .branch("release/2.2"))
]
// Load the bundled TorchScript model.
// NOTE(review): `try!` and force-unwraps are demo-only; handle errors
// properly in production code.
let modelPath = Bundle.main.path(forResource: "resnet18_quantized_scripted", ofType: "pt")!
let module = try! Module(path: modelPath)
// Preprocess the input image: resize to the model's 224x224 input.
let image = UIImage(contentsOfFile: imagePath)!
let inputTensor = try! Tensor(image: image, resize: CGSize(width: 224, height: 224))
// Run inference.
let output = try! module.forward(with: inputTensor)
let outputTensor = output.toTensor()
// Post-process: raw class scores.
let scores = outputTensor.getDataAsFloatArray()
def evaluate(model, data_loader):
    """Return the top-1 accuracy of `model` over `data_loader`, in percent.

    Expects the loader to yield (inputs, labels) batches; the model's
    output is argmax-ed over dim 1 to obtain predicted classes.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch, targets in data_loader:
            preds = model(batch).argmax(dim=1)
            seen += targets.size(0)
            hits += (preds == targets).sum().item()
    return 100 * hits / seen
# Accuracy comparison: FP32 baseline vs INT8 quantized model.
# NOTE(review): `test_loader` is not defined in this snippet -- it is
# assumed to be a standard classification DataLoader.
accuracy_fp32 = evaluate(model, test_loader)
accuracy_int8 = evaluate(model_quantized, test_loader)
print(f'FP32 Accuracy: {accuracy_fp32:.2f}%')
print(f'INT8 Accuracy: {accuracy_int8:.2f}%')
print(f'Accuracy Drop: {accuracy_fp32 - accuracy_int8:.2f}%')
import time
def measure_latency(model, input_tensor, num_runs=100):
    """Return the mean per-inference latency of `model` in milliseconds.

    Runs 10 untimed warm-up inferences (kernel caches, allocator), then
    times `num_runs` forward passes and averages.

    Fixes vs original: uses time.perf_counter() -- the monotonic,
    high-resolution clock meant for benchmarking -- instead of time.time(),
    which has coarse resolution and can jump with wall-clock adjustments;
    also disables autograd, which the original left enabled.
    """
    model.eval()
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        # Warm-up so one-time costs don't pollute the measurement.
        for _ in range(10):
            _ = model(input_tensor)
        start = time.perf_counter()
        for _ in range(num_runs):
            _ = model(input_tensor)
        elapsed = time.perf_counter() - start
    return elapsed / num_runs * 1000  # milliseconds per inference
# Latency comparison on the same dummy input (defined in the export step).
latency_fp32 = measure_latency(model, dummy_input)
latency_int8 = measure_latency(model_quantized, dummy_input)
print(f'FP32 Latency: {latency_fp32:.2f} ms')
print(f'INT8 Latency: {latency_int8:.2f} ms')
print(f'Speedup: {latency_fp32 / latency_int8:.2f}x')
# Per-sample comparison of FP32 vs INT8 model outputs.
def analyze_quantization_error(model_fp32, model_int8, data_loader, num_samples=10):
    """Return {'sample_i': mean-absolute-error} over the first `num_samples` batches.

    Both models are run in eval mode on identical inputs; the error is the
    mean absolute difference between their outputs.
    """
    model_fp32.eval()
    model_int8.eval()
    results = {}
    with torch.no_grad():
        for idx, (batch, _) in enumerate(data_loader):
            if idx >= num_samples:
                break
            reference = model_fp32(batch)   # full-precision reference output
            candidate = model_int8(batch)   # quantized model output
            mae = torch.mean(torch.abs(reference - candidate)).item()
            results[f'sample_{idx}'] = mae
    return results
# Inspect the per-sample error between the FP32 and INT8 models.
errors = analyze_quantization_error(model, model_quantized, test_loader)
print("Quantization errors:", errors)
# Mixed precision: keep a quantization-sensitive layer in FP32.
# NOTE(review): swapping an FP32 module into an already-converted INT8
# model like this will feed it quantized tensors -- it needs
# DeQuantStub/QuantStub handling around the layer to actually run;
# confirm before shipping.
model_quantized.layer4[1].conv2 = model.layer4[1].conv2
# Lower-bit (INT4) weight quantization.
# NOTE(review): `torch.qint4x2` does not exist in public PyTorch builds --
# the packed 4-bit dtype is `torch.quint4x2`, and observer/backend support
# for it is limited; verify against the installed PyTorch version.
from torch.ao.quantization import MinMaxObserver, QConfig
qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint4x2)
)
model.qconfig = qconfig
model_prepared = quantization.prepare(model)
# ...
# Binarization example: a sign non-linearity.
class BinaryActivation(nn.Module):
    """Map each element to its sign: -1, 0, or +1."""

    def forward(self, x):
        return x.sign()
class BinaryConv2d(nn.Conv2d):
    """Conv2d that binarizes its weights with sign() at every forward pass.

    Bug fix: the original called `F.conv2d`, but `F` (torch.nn.functional)
    is never imported anywhere in this file, so the forward pass would
    raise NameError at runtime. Uses the fully qualified name instead.
    """

    def forward(self, x):
        # Binarize on the fly; note sign(0) == 0, so exactly-zero weights
        # stay 0 rather than becoming +/-1.
        w_bin = torch.sign(self.weight)
        return torch.nn.functional.conv2d(
            x, w_bin, self.bias, self.stride,
            self.padding, self.dilation, self.groups)
# Use NAS (via Microsoft NNI) to search for a quantization-friendly
# architecture: each LayerChoice exposes alternative ops to the search.
from nni.nas.pytorch import mutables, model_wrapper

@model_wrapper
class QuantizationFriendlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Candidate convolutions the search algorithm chooses between.
        self.conv1 = mutables.LayerChoice([
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.Conv2d(3, 16, kernel_size=5, padding=2)
        ])
        # ... (rest of the model omitted in the original)
评估阶段(1-2周):
量化实施阶段(2-3周):
优化阶段(1-2周):
集成与测试阶段(1-2周):
工具名称 | 功能描述 | 官网链接 |
---|---|---|
PyTorch | 深度学习框架 | https://pytorch.org/ |
ONNX | 模型格式转换工具 | https://onnx.ai/ |
TensorRT | NVIDIA高性能推理优化工具 | https://developer.nvidia.com/tensorrt |
XNNPACK | 移动端神经网络加速库 | https://github.com/google/XNNPACK |
Core ML Tools | iOS模型转换工具 | https://github.com/apple/coremltools |