5. PyTorch: Project Workflow

1. Data Preparation

Basic steps: create a Dataset (or subclass) object -> pass it to a DataLoader (an iterable, so it can be looped over with for)

1.1 The Dataset Class

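Dataset is an abstract class. A subclass must implement __getitem__ (return one sample for a given index) and usually __len__ (return the number of samples).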
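
Here is a minimal sketch of a custom map-style Dataset. It uses a made-up dataset, SquaresDataset (hypothetical, for illustration only), whose i-th sample is the pair (i, i²). The sketch shows the two methods a subclass typically provides and how a DataLoader batches the samples they return.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is (i, i**2) as float tensors."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32)
        self.y = self.x ** 2

    def __getitem__(self, index):
        # Called by dataset[index] and by the DataLoader for each sample
        return self.x[index], self.y[index]

    def __len__(self):
        # Tells the DataLoader's sampler how many indices exist
        return len(self.x)


dataset = SquaresDataset(6)
print(len(dataset))   # 6
print(dataset[3])     # (tensor(3.), tensor(9.))

# The DataLoader calls __getitem__ for each index and stacks the results into batches
loader = DataLoader(dataset, batch_size=4, shuffle=False)
for xb, yb in loader:
    print(xb.shape, yb.shape)
```

With 6 samples and batch_size=4 the loop yields one batch of 4 and one of 2.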

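  • Note
    • Reading a Dataset's raw data attribute bypasses the transform. The transform is applied only inside __getitem__: when the dataset is indexed (training_data[i]) or loaded through a DataLoader, which indexes it for you.
      import torchvision
      training_data = torchvision.datasets.MNIST(root="./mnist", train=True,
                                                 transform=torchvision.transforms.ToTensor(), download=True)
      # The raw attribute still holds uint8 pixels in 0-255, not 0-1 (.data replaces the deprecated .train_data)
      print(training_data.data[0])
      # Indexing goes through __getitem__, so ToTensor is applied: floats in 0-1
      print(training_data[0][0].max())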
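To see where the transform actually runs, here is a sketch that mimics the pattern torchvision's MNIST follows: raw uint8 images are kept in .data, and the transform is applied inside __getitem__. FakeImageDataset and to_unit_range are hypothetical stand-ins. to_unit_range plays the role of ToTensor so the example runs without downloading anything.

```python
import torch
from torch.utils.data import Dataset, DataLoader


def to_unit_range(img):
    """Hypothetical stand-in for transforms.ToTensor(): uint8 HxW -> float 1xHxW in [0, 1]."""
    return img.unsqueeze(0).float() / 255.0


class FakeImageDataset(Dataset):
    """Hypothetical dataset that stores raw uint8 images in .data, like MNIST does."""

    def __init__(self, transform=None):
        # Two 2x2 "images" with raw pixel values in 0-255
        self.data = torch.tensor([[[0, 255], [128, 64]],
                                  [[255, 255], [0, 0]]], dtype=torch.uint8)
        self.targets = torch.tensor([1, 0])
        self.transform = transform

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        if self.transform is not None:
            img = self.transform(img)  # the transform runs here, once per sample
        return img, target

    def __len__(self):
        return len(self.data)


ds = FakeImageDataset(transform=to_unit_range)
print(ds.data[0].max())       # raw attribute, untransformed: tensor(255, dtype=torch.uint8)
img, label = ds[0]            # indexing goes through __getitem__
print(img.shape, img.max())   # torch.Size([1, 2, 2]) tensor(1.)
batch, labels = next(iter(DataLoader(ds, batch_size=2)))
print(batch.shape)            # torch.Size([2, 1, 2, 2])
```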
      

1.1.1 A Dataset subclass: TensorDataset

  • Source code walkthrough
    from torch.utils.data import Dataset


    class TensorDataset(Dataset):
        """Dataset wrapping tensors.

        Each sample will be retrieved by indexing tensors along the first dimension.

        Arguments:
            *tensors (Tensor): tensors that have the same size of the first dimension.
        Example:
            x = torch.linspace(1, 10, 10)
            y = torch.linspace(10, 1, 10)
            dataset = TensorDataset(x, y)
        """

        def __init__(self, *tensors):
            """
            &1
            tensors[0] is x and tensors[1] is y. Every tensor must have the same size
            along dim 0 (the number of samples), hence the assert.
            Any number of tensors can be passed: TensorDataset(x, y, z, ...)
            """
            assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
            self.tensors = tensors

        def __getitem__(self, index):
            """
            &2
            Equivalent to overloading the [] operator:
            dataset[index] returns (x[index], y[index], ...)
            """
            return tuple(tensor[index] for tensor in self.tensors)

        def __len__(self):
            # Number of samples = size of the first dimension
            return self.tensors[0].size(0)
    
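  • Example code
    import torch
    import torch.utils.data as Data

    # The main guard is required when num_workers > 0 on platforms that spawn worker processes (e.g. Windows, macOS)
    if __name__ == "__main__":
        x = torch.linspace(1, 10, 10)
        y = torch.linspace(10, 1, 10)
        dataset = Data.TensorDataset(x, y)
        # &1
        # If the last batch has fewer than batch_size samples (say only 2 remain),
        # it contains just those 2; pass drop_last=True to discard it instead
        # shuffle=True reshuffles the data every epoch (use it for training)
        # num_workers is the number of subprocesses used for loading
        dataloader = Data.DataLoader(dataset=dataset, batch_size=5,
                                     shuffle=True, num_workers=2)
        for epoch in range(3):
            for step, input_data in enumerate(dataloader):
                print(f"{epoch}-{step}:\n{input_data}")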
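You can check the two dunder methods and the last-batch behaviour directly. This sketch uses only TensorDataset and DataLoader with the same x and y as the example. batch_size=4 is an illustrative choice, picked so that 10 samples do not divide evenly. num_workers is left at its default of 0, so no main guard is needed.

```python
import torch
import torch.utils.data as Data

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
dataset = Data.TensorDataset(x, y)

print(len(dataset))  # 10, from __len__ (size of the first dimension)
print(dataset[0])    # (tensor(1.), tensor(10.)), from __getitem__

# batch_size=4 does not divide 10, so the last batch is smaller...
sizes = [len(xb) for xb, yb in Data.DataLoader(dataset, batch_size=4)]
print(sizes)  # [4, 4, 2]

# ...unless drop_last=True discards the incomplete final batch
sizes_dropped = [len(xb) for xb, yb in Data.DataLoader(dataset, batch_size=4, drop_last=True)]
print(sizes_dropped)  # [4, 4]

# Tensors with different sizes along dim 0 fail the assert in __init__
try:
    Data.TensorDataset(torch.zeros(10), torch.zeros(9))
except AssertionError as e:
    print("AssertionError:", e)
```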
    

2. Building the Network

2.1 Class style (subclassing nn.Module)
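
Here is a minimal class-style sketch. It assumes the same architecture as the Sequential network in 2.2: Linear(2, 10) -> ReLU -> Linear(10, 2). The class name Net and the layer names hidden/out are arbitrary choices. Layers are created in __init__, and the computation is written out explicitly in forward.

```python
import torch


class Net(torch.nn.Module):
    """Class-style network: Linear(2, 10) -> ReLU -> Linear(10, 2)."""

    def __init__(self, n_in=2, n_hidden=10, n_out=2):
        super().__init__()  # required so nn.Module can register the sub-layers
        self.hidden = torch.nn.Linear(n_in, n_hidden)
        self.out = torch.nn.Linear(n_hidden, n_out)

    def forward(self, x):
        # Call net(x) rather than net.forward(x) so that hooks are run
        x = torch.relu(self.hidden(x))
        return self.out(x)


net = Net()
out = net(torch.randn(4, 2))   # a batch of 4 samples with 2 features each
print(out.shape)               # torch.Size([4, 2])
print(sum(p.numel() for p in net.parameters()))  # 52 = (2*10 + 10) + (10*2 + 2)
```

The class style is more verbose, but forward can contain arbitrary Python control flow (branches, skip connections, multiple inputs), which a plain Sequential chain cannot express.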

2.2 Sequential style

import torch

net = torch.nn.Sequential(
    torch.nn.Linear(2, 10),   # 2 input features -> 10 hidden units
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2)    # 10 hidden units -> 2 outputs
)
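
Here is a short sketch of using the Sequential net. Calling it runs the layers in order, the layers can be indexed like a list, and passing an OrderedDict gives them names. The input shape (4, 2) and the layer names hidden/act/out are illustrative assumptions.

```python
import torch
from collections import OrderedDict

net = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2),
)

x = torch.randn(4, 2)
out = net(x)          # layers are applied in order: Linear -> ReLU -> Linear
print(out.shape)      # torch.Size([4, 2])
print(net[0])         # Linear(in_features=2, out_features=10, bias=True)

# Same architecture with named layers, accessible as attributes
named = torch.nn.Sequential(OrderedDict([
    ("hidden", torch.nn.Linear(2, 10)),
    ("act", torch.nn.ReLU()),
    ("out", torch.nn.Linear(10, 2)),
]))
print(named.hidden)   # Linear(in_features=2, out_features=10, bias=True)
```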
