tqdm 是 Python 中常用的第三方进度条库(并非 Python 自带,需先通过 `pip install tqdm` 安装),借助它可以很方便地展示深度学习训练过程中的更新信息。它支持滚动更新每步训练的损失、准确率、每一步的耗时以及剩余时间。
代码如下:
import time
import numpy as np
import torch
from tqdm import tqdm
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Training loader: shuffled batches of 64; the trailing incomplete batch is
# dropped. `train_ds` / `test_ds` must be defined before this point.
train_dl = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)
# Evaluation loader: batches of 256, default (sequential) ordering.
test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=256)

# Instantiate the network and place it on the selected device.
# NOTE(review): `Model` is defined elsewhere; presumably a torch.nn.Module.
model = Model()
model.to(device)

# Cross-entropy loss with default settings, optimized by Adam at lr = 0.001.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def train_one_epoch(epoch, model, train_loader, test_loader):
    """Train `model` for one epoch, then evaluate it, with tqdm progress bars.

    Relies on the module-level names `device`, `loss_fn`, `optimizer` and
    `epochs` (the latter is only used for the progress-bar label).

    Args:
        epoch: Zero-based index of the current epoch (displayed as epoch + 1).
        model: The network to train and evaluate.
        train_loader: Iterable of (x, y) batches used for the optimization pass.
        test_loader: Iterable of (x, y) batches used for the evaluation pass.

    Returns:
        A tuple (epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc).
        Losses are per-sample averages and accuracies are the fraction of
        correct predictions. A value is NaN if its loader yielded no samples.
    """
    # ---- Training pass ----
    correct = 0
    total = 0
    sum_loss = 0.0
    model.train()
    # The description never changes within an epoch, so set it once here
    # instead of calling set_description on every step.
    loop = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{epochs}]')
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            batch_size = y.size(0)
            y_ = torch.argmax(y_pred, dim=1)
            correct += (y_ == y).sum().item()
            total += batch_size
            # `loss` is assumed to be the mean over the batch (the default
            # reduction of the loss_fn built in this file), so weight it by
            # the batch size; dividing by `total` then gives a true
            # per-sample average that does not depend on the batch size.
            sum_loss += loss.item() * batch_size
        # Refresh the running metrics shown on the progress bar.
        loop.set_postfix(loss=_safe_ratio(sum_loss, total),
                         acc=_safe_ratio(correct, total))
    epoch_loss = _safe_ratio(sum_loss, total)
    epoch_acc = _safe_ratio(correct, total)

    # ---- Evaluation pass ----
    test_correct = 0
    test_total = 0
    test_sum_loss = 0.0
    model.eval()
    with torch.no_grad():
        loop2 = tqdm(test_loader, desc='Test')
        for x, y in loop2:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            batch_size = y.size(0)
            y_ = torch.argmax(y_pred, dim=1)
            test_correct += (y_ == y).sum().item()
            test_total += batch_size
            # Same batch-size weighting as in the training pass.
            test_sum_loss += loss.item() * batch_size
            # Refresh the running metrics shown on the progress bar.
            loop2.set_postfix(loss=_safe_ratio(test_sum_loss, test_total),
                              acc=_safe_ratio(test_correct, test_total))
    test_epoch_loss = _safe_ratio(test_sum_loss, test_total)
    test_epoch_acc = _safe_ratio(test_correct, test_total)
    return epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc
# Training driver: run every epoch and record the per-epoch metrics.
epochs = 10
train_loss, train_acc, test_loss, test_acc = [], [], [], []
# Same order as the tuple returned by train_one_epoch.
histories = (train_loss, train_acc, test_loss, test_acc)
for epoch in range(epochs):
    results = train_one_epoch(epoch, model, train_dl, test_dl)
    epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc = results
    for history, value in zip(histories, results):
        history.append(value)
Epoch [1/10]: 100%|██████████| 937/937 [00:08<00:00, 110.80it/s, acc=0.815, loss=0.00795]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.80it/s, acc=0.864, loss=0.00149]
Epoch [2/10]: 100%|██████████| 937/937 [00:08<00:00, 109.28it/s, acc=0.879, loss=0.00512]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.69it/s, acc=0.871, loss=0.00138]
Epoch [3/10]: 100%|██████████| 937/937 [00:08<00:00, 107.06it/s, acc=0.899, loss=0.00433]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.56it/s, acc=0.897, loss=0.00112]
Epoch [4/10]: 100%|██████████| 937/937 [00:08<00:00, 106.62it/s, acc=0.91, loss=0.00382]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.62it/s, acc=0.89, loss=0.00116]
Epoch [5/10]: 100%|██████████| 937/937 [00:08<00:00, 104.30it/s, acc=0.918, loss=0.00343]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.77it/s, acc=0.903, loss=0.00105]
Epoch [6/10]: 100%|██████████| 937/937 [00:08<00:00, 106.34it/s, acc=0.927, loss=0.00304]
Test: 100%|██████████| 40/40 [00:00<00:00, 44.41it/s, acc=0.904, loss=0.00103]
Epoch [7/10]: 100%|██████████| 937/937 [00:08<00:00, 105.92it/s, acc=0.935, loss=0.00273]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.66it/s, acc=0.911, loss=0.00098]
Epoch [8/10]: 100%|██████████| 937/937 [00:09<00:00, 102.60it/s, acc=0.941, loss=0.00245]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.58it/s, acc=0.908, loss=0.00105]
Epoch [9/10]: 100%|██████████| 937/937 [00:08<00:00, 107.29it/s, acc=0.948, loss=0.00219]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.58it/s, acc=0.914, loss=0.00103]
Epoch [10/10]: 100%|██████████| 937/937 [00:08<00:00, 109.47it/s, acc=0.954, loss=0.00195]
Test: 100%|██████████| 40/40 [00:00<00:00, 42.52it/s, acc=0.919, loss=0.00102]