Mastering training-progress monitoring is a fundamental skill for a deep learning engineer. This article walks you, from scratch, through adding professional-grade progress bars to deep learning training with tqdm.
In deep learning training we constantly watch long-running loops: epochs, batches, data loading, validation. The traditional approach, sprinkling print statements through the loop, has clear drawbacks: it floods the console, gives no estimate of speed or remaining time, and leaves logs hard to read.

tqdm (from the Arabic taqaddum, "progress") is the most popular progress-bar library in Python.
Install it with:

```bash
pip install tqdm
```
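If you manage environments with conda instead, the same package is also published on conda-forge:

```bash
conda install -c conda-forge tqdm
```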
```python
from tqdm import tqdm
import time

# the simplest possible progress bar: wrap any iterable
for i in tqdm(range(100)):
    time.sleep(0.02)  # simulate work
```
Output:

```text
100%|██████████| 100/100 [00:02<00:00, 49.80it/s]
```
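tqdm also ships trange, a shorthand for tqdm(range(...)) that accepts the same keyword arguments; a minimal sketch:

```python
from tqdm import trange
import time

# trange(100, ...) is equivalent to tqdm(range(100), ...)
for i in trange(100, desc="Warmup"):
    time.sleep(0.02)
```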
Wrapping a PyTorch DataLoader works the same way:

```python
from torch.utils.data import DataLoader
from tqdm import tqdm

# dataset is assumed to be defined elsewhere
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# wrap the loader in tqdm to get a bar over batches
for batch in tqdm(dataloader, desc="Loading Data"):
    # data preprocessing code goes here
    pass
```
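If you want something you can actually run, here is a minimal self-contained sketch with a synthetic TensorDataset (shapes and sizes are arbitrary placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# synthetic data: 1024 samples, 10 features, binary labels
features = torch.randn(1024, 10)
labels = torch.randint(0, 2, (1024,))
dataset = TensorDataset(features, labels)

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
for data, target in tqdm(dataloader, desc="Loading Data"):
    pass  # 1024 / 64 = 16 batches, so the bar counts up to 16
```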
A typical training function with a live running-loss readout (criterion is passed in explicitly so the function is self-contained):

```python
def train(model, dataloader, criterion, optimizer, epoch):
    model.train()
    total_loss = 0
    # create the progress bar with a per-epoch description
    pbar = tqdm(enumerate(dataloader), total=len(dataloader),
                desc=f'Epoch {epoch+1} [Train]')
    for batch_idx, (data, target) in pbar:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # dynamically update the bar with the running average loss
        avg_loss = total_loss / (batch_idx + 1)
        pbar.set_postfix(loss=f'{avg_loss:.4f}')
```
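To wire it up end to end, a hedged usage sketch; the model, loss, and optimizer below are placeholders matching the synthetic dataset above, not the article's prescription:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# one bar per epoch, each drawn on its own line
for epoch in range(3):
    train(model, dataloader, criterion, optimizer, epoch)
```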
And a validation function that reports accuracy in real time:

```python
import torch
from tqdm import tqdm

def validate(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    # disable gradient computation to speed things up
    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validating', leave=False)
        for data, target in pbar:
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            # update the accuracy readout live
            acc = 100 * correct / total
            pbar.set_postfix(acc=f'{acc:.2f}%')
    return 100 * correct / total
```
The bar's appearance is fully customizable:

```python
# customize the bar's format, width, and colour
pbar = tqdm(dataloader,
            bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}',
            ncols=100,       # fix the overall width
            colour='GREEN')  # set the bar colour
```
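A few other documented display parameters are worth knowing; the values here are just illustrative:

```python
pbar = tqdm(dataloader,
            unit='batch',        # label iterations as batches, not "it"
            dynamic_ncols=True,  # adapt the width to terminal resizes
            smoothing=0.3)       # weight the speed estimate toward recent rates
```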
For nested loops (epochs outside, batches inside), stack two bars and clear the inner one with leave=False:

```python
from tqdm.auto import tqdm, trange

for epoch in trange(10, desc='Epochs'):
    # outer bar: one tick per epoch
    for batch in tqdm(dataloader, desc='Batches', leave=False):
        # inner bar: cleared after each epoch thanks to leave=False
        pass
```
In distributed training, make sure only the main process draws a bar:

```python
# show the progress bar only on the main process
if local_rank == 0:
    pbar = tqdm(total=len(dataloader))
else:
    pbar = None
```
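Alternatively, a sketch using tqdm's disable flag avoids the None checks: every rank constructs the bar, but only rank 0 renders it (local_rank is assumed to come from your launcher):

```python
# disable=True turns the bar into a transparent no-op wrapper
pbar = tqdm(dataloader, disable=(local_rank != 0))
for batch in pbar:
    pass  # training step; non-zero ranks iterate silently
```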
To mix Python logging with tqdm without mangling the bar, route log records through tqdm.write:

```python
import logging
from tqdm import tqdm

class TqdmLoggingHandler(logging.Handler):
    def emit(self, record):
        msg = self.format(record)
        tqdm.write(msg)  # prints above the bar without breaking it

logger = logging.getLogger()
logger.addHandler(TqdmLoggingHandler())
```
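A quick usage sketch to see the effect; log lines appear above the bar instead of tearing it apart:

```python
import time

logger.setLevel(logging.INFO)
for step in tqdm(range(50), desc='Training'):
    time.sleep(0.05)
    if step % 10 == 0:
        logger.info(f'checkpoint at step {step}')  # routed via tqdm.write
```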
A few performance tips. Set a sensible refresh rate:

```python
pbar = tqdm(dataloader, mininterval=0.5)  # refresh at most every 0.5 s
```

Avoid overly frequent updates by driving the bar manually:

```python
# advance the bar in chunks of 10 batches
pbar = tqdm(total=len(dataloader))
for batch_idx, batch in enumerate(dataloader, start=1):
    if batch_idx % 10 == 0:
        pbar.update(10)
pbar.close()
```

And turn off non-essential bars entirely:

```python
# disable the bar during fast iteration or debug runs
pbar = tqdm(dataloader, disable=fast_mode)
```
Putting it all together, a complete training skeleton:

```python
from tqdm.auto import tqdm
import torch

def train_model(model, train_loader, val_loader, optimizer, epochs):
    best_acc = 0
    # outer bar: epoch-level progress
    epoch_bar = tqdm(range(epochs), desc="Total Progress", position=0)
    for epoch in epoch_bar:
        # training phase
        model.train()
        batch_bar = tqdm(train_loader, desc=f"Train Epoch {epoch+1}",
                         position=1, leave=False)
        for data, target in batch_bar:
            # training step (forward/backward/optimizer) goes here,
            # producing a `loss` tensor
            batch_bar.set_postfix(loss=f"{loss.item():.4f}")

        # validation phase
        val_acc = validate(model, val_loader)
        # update the main bar with the latest validation accuracy
        epoch_bar.set_postfix(val_acc=f"{val_acc:.2f}%")

        # save the best model so far
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")

    print(f"\nTraining Complete! Best Val Acc: {best_acc:.2f}%")
```
Q: What if the bar renders strangely (overlapping or duplicated lines)?

```python
# set the position parameter to keep multiple bars from colliding
tqdm(..., position=0)  # outer bar
tqdm(..., position=1)  # inner bar
```

Q: The bar doesn't show up in Jupyter Notebook?

```python
# use the notebook-specific version
from tqdm.notebook import tqdm
```

Q: How do I resume an interrupted run?

```python
# start the bar from the step you resumed at
pbar = tqdm(total=100, initial=resume_step)
```
By now you have seen how to:

- wrap any iterable or DataLoader in a tqdm bar
- surface live metrics in training and validation loops with set_postfix
- customize the bar's format, width, and colour
- handle nested loops, distributed training, and logging cleanly
- keep the bar's overhead negligible with mininterval, chunked updates, and disable

Best-practice suggestions:

- use leave=False on inner or transient bars so they clean up after themselves
- report metrics through set_postfix rather than print statements
- throttle refreshes on fast loops, and disable bars entirely in quick debug runs
"Good tools don't change the essence of an algorithm, but they can dramatically improve development experience and efficiency. tqdm is exactly that kind of Swiss Army knife for deep learning productivity."
Further reading: the official tqdm documentation (https://tqdm.github.io/) and the project repository (https://github.com/tqdm/tqdm).
Used well, tqdm gives your deep learning workflow professional-grade progress monitoring, noticeably improving both development efficiency and the observability of training runs.