Deep Learning with PyTorch: Printing Training Progress with tqdm Progress Bars

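tqdm is a third-party Python progress-bar library. It is not part of the standard library, so install it with pip install tqdm. It makes it easy to display live updates while a deep learning model trains. On every step it can refresh the running loss and accuracy in place, along with the processing speed (iterations per second), the elapsed time and the estimated time remaining.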
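Here is a minimal sketch of the two calls the training code below relies on, set_description and set_postfix. For illustration only, the bar wraps a plain range instead of a DataLoader, and the loss and accuracy values are simulated. The output also goes to an io.StringIO buffer instead of the terminal.

```python
import io

from tqdm import tqdm

# Send the bar to an in-memory buffer instead of stderr so the result can be inspected
buffer = io.StringIO()
bar = tqdm(range(100), desc='Train', file=buffer)
for step in bar:
    fake_loss = 1.0 / (step + 1)  # simulated, steadily decreasing loss
    fake_acc = step / 100         # simulated, steadily increasing accuracy
    bar.set_description('Epoch [1/1]')             # replaces the label on the left
    bar.set_postfix(loss=fake_loss, acc=fake_acc)  # shown after the timing info

# set_postfix sorts keyword arguments by name, so 'acc' is listed before 'loss'
print(bar.postfix)  # acc=0.99, loss=0.01
```

The alphabetical sorting also explains why the sample log at the end shows acc before loss, even though the code passes loss first.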

The code is as follows:

import torch
from tqdm import tqdm

Data preparation

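# train_ds and test_ds are existing torch.utils.data.Dataset objects (not shown in this article)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True, drop_last=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=256)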
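The number of steps each progress bar shows is len(DataLoader), and that number depends on drop_last. The sketch below uses small synthetic TensorDataset objects as hypothetical stand-ins for train_ds and test_ds. Each holds 1,000 random 1x28x28 samples; these sizes and shapes are assumptions, not the author's data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for train_ds / test_ds: 1,000 random images with labels 0-9
x = torch.randn(1000, 1, 28, 28)
y = torch.randint(0, 10, (1000,))
train_ds = TensorDataset(x, y)
test_ds = TensorDataset(x, y)

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, drop_last=True)
test_dl = DataLoader(test_ds, batch_size=256)

# drop_last=True discards the incomplete final batch: floor(1000 / 64) = 15 steps
print(len(train_dl))  # 15
# Without drop_last the partial final batch is kept: ceil(1000 / 256) = 4 steps
print(len(test_dl))   # 4
```

By the same arithmetic, the 937 training steps per epoch in the sample log imply a training set of between 59,968 (937 x 64) and 60,031 samples.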

Create the model

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Model()  # Model is a user-defined torch.nn.Module subclass, not shown here
model.to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
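The article does not show the Model class. The following is a purely hypothetical example that assumes single-channel 28x28 inputs and 10 classes. It is a small fully connected network that fits the rest of the code, and it returns raw logits because CrossEntropyLoss expects them.

```python
import torch
from torch import nn


class Model(nn.Module):
    """Hypothetical classifier: flattens a 1x28x28 image and returns 10 class logits."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Raw logits; CrossEntropyLoss applies log-softmax internally
        return self.net(x)


model = Model()
logits = model(torch.randn(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4, 10])
```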

Define a function that runs one epoch of training followed by evaluation on the test set

def train_one_epoch(epoch, model, train_loader, test_loader):

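    # Training phase
    correct = 0
    total = 0
    sum_loss = 0.0
    
    model.train()
    loop = tqdm(train_loader, desc='Train')
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            y_ = torch.argmax(y_pred, dim=1)
            correct += (y_ == y).sum().item()
            total += y.size(0)
            # loss.item() is the mean over the batch, so weight it by the batch size
            # before dividing by the total number of samples seen so far
            sum_loss += loss.item() * y.size(0)
            running_loss = sum_loss / total
            running_acc = correct / total
            
        # Update the progress bar: epoch counter on the left, running metrics on the right
        # (epochs is the global defined before the training loop calls this function)
        loop.set_description(f'Epoch [{epoch+1}/{epochs}]')
        loop.set_postfix(loss=running_loss, acc=running_acc)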
    
    epoch_loss = sum_loss / total
    epoch_acc = correct / total
    
    # Evaluation phase
    test_correct = 0
    test_total = 0
    test_sum_loss = 0.0
    
    model.eval()
    with torch.no_grad():

        loop2 = tqdm(test_loader, desc='Test')
        for x, y in loop2:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_ = torch.argmax(y_pred, dim=1)
            test_correct += (y_ == y).sum().item()
            test_total += y.size(0)
            # Weight the batch-mean loss by the batch size, as in the training phase
            test_sum_loss += loss.item() * y.size(0)
            test_running_loss = test_sum_loss / test_total
            test_running_acc = test_correct / test_total
            
            # Update the evaluation progress bar
            loop2.set_postfix(loss=test_running_loss, acc=test_running_acc)
            
    test_epoch_loss = test_sum_loss / test_total
    test_epoch_acc = test_correct / test_total
    
    return epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc
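CrossEntropyLoss uses reduction='mean' by default, so loss.item() is already an average over the batch. A per-sample running mean therefore needs each batch loss weighted by its batch size before dividing by the number of samples seen. The framework-free sketch below compares that with dividing the plain sum of batch means by the sample count. RunningMeter is a hypothetical helper name, not part of PyTorch or tqdm, and the batch values are made up for the example.

```python
class RunningMeter:
    """Hypothetical helper that tracks sample-weighted running loss and accuracy."""

    def __init__(self):
        self.sum_loss = 0.0  # sum of per-sample losses
        self.correct = 0
        self.total = 0

    def update(self, batch_mean_loss: float, batch_correct: int, batch_size: int) -> None:
        # batch_mean_loss plays the role of loss.item() with reduction='mean'
        self.sum_loss += batch_mean_loss * batch_size
        self.correct += batch_correct
        self.total += batch_size

    @property
    def loss(self) -> float:
        return self.sum_loss / self.total

    @property
    def acc(self) -> float:
        return self.correct / self.total


meter = RunningMeter()
meter.update(batch_mean_loss=0.5, batch_correct=48, batch_size=64)
meter.update(batch_mean_loss=1.0, batch_correct=18, batch_size=36)

naive_loss = (0.5 + 1.0) / 100  # summed batch means divided by the sample count
print(meter.loss, meter.acc, naive_loss)  # 0.68 0.66 0.015
```

With equal-sized batches of size B, the naive value comes out exactly B times too small.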

Train for multiple epochs

# Training
epochs = 10
train_loss = []
train_acc  = []
test_loss = []
test_acc  = []

for epoch in range(epochs):
    epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc = train_one_epoch(epoch, model, train_dl, test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(test_epoch_loss)
    test_acc.append(test_epoch_acc)
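You may also want a one-line summary after each epoch. tqdm provides tqdm.write for this; it prints a message without breaking any progress bar still on screen, which matters if you later wrap the epoch loop in its own tqdm bar. The sketch below is only one way to use it. The list values are placeholders standing in for train_loss, train_acc, test_loss and test_acc (not real results), and the output goes to an in-memory buffer so it can be checked.

```python
import io

from tqdm import tqdm

# Placeholder values standing in for the lists filled by the training loop
train_loss, train_acc = [0.5, 0.4], [0.8, 0.85]
test_loss, test_acc = [0.45, 0.42], [0.82, 0.84]

buffer = io.StringIO()  # omit file=... in a real script to print to stdout
for epoch in range(len(train_loss)):
    tqdm.write(
        f'epoch {epoch + 1}: train loss {train_loss[epoch]:.4f}, acc {train_acc[epoch]:.3f} | '
        f'test loss {test_loss[epoch]:.4f}, acc {test_acc[epoch]:.3f}',
        file=buffer,
    )

print(buffer.getvalue())
```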
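
Sample output. tqdm sorts the keyword arguments of set_postfix alphabetically, so acc is shown before loss. The loss values here were computed as the sum of batch-mean losses divided by the number of samples, which makes them roughly batch-size times smaller than the per-sample mean loss:
Epoch [1/10]: 100%|██████████| 937/937 [00:08<00:00, 110.80it/s, acc=0.815, loss=0.00795]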
Test: 100%|██████████| 40/40 [00:00<00:00, 43.80it/s, acc=0.864, loss=0.00149]
Epoch [2/10]: 100%|██████████| 937/937 [00:08<00:00, 109.28it/s, acc=0.879, loss=0.00512]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.69it/s, acc=0.871, loss=0.00138]
Epoch [3/10]: 100%|██████████| 937/937 [00:08<00:00, 107.06it/s, acc=0.899, loss=0.00433]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.56it/s, acc=0.897, loss=0.00112]
Epoch [4/10]: 100%|██████████| 937/937 [00:08<00:00, 106.62it/s, acc=0.91, loss=0.00382] 
Test: 100%|██████████| 40/40 [00:00<00:00, 43.62it/s, acc=0.89, loss=0.00116] 
Epoch [5/10]: 100%|██████████| 937/937 [00:08<00:00, 104.30it/s, acc=0.918, loss=0.00343]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.77it/s, acc=0.903, loss=0.00105]
Epoch [6/10]: 100%|██████████| 937/937 [00:08<00:00, 106.34it/s, acc=0.927, loss=0.00304]
Test: 100%|██████████| 40/40 [00:00<00:00, 44.41it/s, acc=0.904, loss=0.00103]
Epoch [7/10]: 100%|██████████| 937/937 [00:08<00:00, 105.92it/s, acc=0.935, loss=0.00273]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.66it/s, acc=0.911, loss=0.00098] 
Epoch [8/10]: 100%|██████████| 937/937 [00:09<00:00, 102.60it/s, acc=0.941, loss=0.00245]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.58it/s, acc=0.908, loss=0.00105]
Epoch [9/10]: 100%|██████████| 937/937 [00:08<00:00, 107.29it/s, acc=0.948, loss=0.00219]
Test: 100%|██████████| 40/40 [00:00<00:00, 43.58it/s, acc=0.914, loss=0.00103] 
Epoch [10/10]: 100%|██████████| 937/937 [00:08<00:00, 109.47it/s, acc=0.954, loss=0.00195]
Test: 100%|██████████| 40/40 [00:00<00:00, 42.52it/s, acc=0.919, loss=0.00102] 
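Each log line follows tqdm's default layout: the description, the percentage and bar, n/total, and then [elapsed<remaining, rate, postfix]. The static method tqdm.format_meter builds that string from raw numbers, so it shows how a line is assembled. The elapsed time of 8.46 s below is an assumed value, chosen to give a rate close to that of the first training epoch.

```python
from tqdm import tqdm

# Rebuild a line similar to the first training line of the log from raw numbers
line = tqdm.format_meter(
    n=937,                              # steps completed
    total=937,                          # steps per epoch, i.e. len(train_dl)
    elapsed=8.46,                       # assumed elapsed seconds
    prefix='Epoch [1/10]',              # the text set_description sets
    postfix='acc=0.815, loss=0.00795',  # the text set_postfix produces
)
print(line)
```

Read this way, the first training line says 937 of 937 steps are done, with 8 s elapsed, 0 s remaining and about 111 iterations per second, followed by the postfix metrics.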
