Python深度学习:enumerate函数实战应用

在深度学习中,enumerate() 函数,主要用于在遍历可迭代对象(如列表、元组、字符串等)时,同时获取元素的索引和值,一般用于for循环当中。它常用于数据加载、模型训练和调试等场景,能够简化代码逻辑并提高可读性。Python 2.3. 以上版本可用,2.6 添加 start 参数。

例如

普通的for循环

i = 0
str_1 = ['hello', 'hi', 'ok', 'yes']
for elements in str_1:
    print(i, str_1[i])
    i += 1

0 hello

1 hi

2 ok

3 yes

for 循环使用 enumerat

str_1 = ['hello', 'hi', 'ok', 'yes']
for i, elements in enumerate(str_1):
    print(i, str_1[i])

0 hello

1 hi

2 ok

3 yes

以下是完整的循环代码,结合模型优化器等。

import torch
from torch.utils.data import DataLoader, TensorDataset

X = torch.randn(100, 3, 224, 224)  # 输入数据,100  3 通道的 224x224 图像
y = torch.randint(0, 10, (100,))  # 标签(0~9 的类别)

dataset = TensorDataset(X, y)
tran_d = DataLoader(dataset, batch_size=10, shuffle=True)  # 每次迭代10个数据并将数据打乱(shuffle

model = torch.nn.Linear(3 * 224 * 224, 10)   # 定义简单模型(传统神经网络),输入数据为3 * 224 * 224,输出10个类别,也可以定义卷积torch.nn.Conv2d
criterion = torch.nn.CrossEntropyLoss()       # 定义交叉熵损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 优化器选用SGD(目前在TransformerBERT这些经典模型均使用的Adam)学习率0.01

# 训练循环
num_epochs = 5
for epoch in range(num_epochs):
    for batch_idx, (inputs, targets) in enumerate(tran_d):
        # 返回的 batch_idx 是当前批次的索引(从 0 开始)
        # 返回的 (inputs, targets)  DataLoader 产生的批次数据,包含输入数据和对应的标签
        # inputs:一个批次的输入数据(如形状为 [batch_size, c, h, w]
        # targets:一个批次的目标标签(如形状为 [batch_size] 的类别标签)
        # 假设是全连接网络
        inputs = inputs.view(inputs.size(0), -1)

        # 前向传播
        outputs = model(inputs)             # 输入 inputs 传递给模型 model,执行forward ,得到outputs
        loss = criterion(outputs, targets)  # 计算模型输出 outputs 与真实标签 targets 之间loss

        # 反向传播和优化
        optimizer.zero_grad()  # 梯度清零
        loss.backward()        # 反向传播,计算损失
        optimizer.step()       # 更新参数

        # 训练5batch打印一次日志(类似于打印损失,训练一定数量batch打印一次损失)
        if batch_idx % 5 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")

输出 

Epoch 0, Batch 0, Loss: 2.142916440963745
Epoch 0, Batch 5, Loss: 2.3101401329040527
Epoch 1, Batch 0, Loss: 0.0
Epoch 1, Batch 5, Loss: 0.0
Epoch 2, Batch 0, Loss: 0.0
Epoch 2, Batch 5, Loss: 0.0
Epoch 3, Batch 0, Loss: 0.0
Epoch 3, Batch 5, Loss: 0.0
Epoch 4, Batch 0, Loss: 0.0
Epoch 4, Batch 5, Loss: 0.0

你可能感兴趣的:(python,深度学习,人工智能)