pytorch学习(Dataset和Tensorboard)(二)

pytorch加载数据

主要涉及到两个类:
Dataset提供一种方法去获取数据和label。

  • 如何获取每一个数据及其label。
  • 告诉我们一共有多少数据。

Dataloader为后面的网络提供不同的数据形式。

举例

如数据集:https://download.pytorch.org/tutorial/hymenoptera_data.zip
数据集

  • 在这里插入图片描述
    • pytorch学习(Dataset和Tensorboard)(二)_第1张图片
    • pytorch学习(Dataset和Tensorboard)(二)_第2张图片
from torch.utils.data import Dataset #需要重写__getitem__方法和选择重写__len__方法。
import PIL import Image
import os
img_path = "D:\\PycharmProjects\\pytorch\\dataset\\train\\ants\\0013035.jpg"
dir_path = "dataset/train/ants"
img_path_list = os.listdir(dir.path) #这里是一个list,list里面是所有蚂蚁图片的名称(str类型)

class MyData(Dataset):
	def __init__(self, root_dir, label_dir): #root_dir = "dataset/train",label_dir = "ants"
		self.root_dir = root_dir
		self.label_dir = label_dir
		self.img_path_list = os.listdir(self_path) #这里是一个list,list里面是所有蚂蚁图片的名称(str类型)
	def __getitem__(self,idx):
		img_name = self.img_path_list[idx]
		img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
		img = Image.open(img_item_path )
		label = self.label_dir
		return img, label
	def __len__():
	return len(self.img_path_list)

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
img, label = ants_dataset[0]#获取蚂蚁数据集的第一个数据和label

bees_dataset = MyData(root_dir, bees_label_dir)

train_dataset = ants_dataset + ants_dataset #可以直接相加整合

TensorBoard的使用

TensorBoard主要用来显示训练过程的一些结果,如训练过程loss是如何变化的,图像是如何变化的。。
先到需要的环境下安装该包,pip install tensorboard

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs") #创建logs文件夹,后续的事件文件都会输入到该文件夹内。
# writer.add_image()
# 例:y = x
for i in range(100):
	writer.add_scalar("y = x", i, i)
# 例 y = 2x
for i in range(100):
	writer.add_scalar("y = 2x", 2*i, i)
for i in range(100):
	writer.add_scalar("y = 2x", 3*i, i) #如果名字不改的话,数据都会写入"y = 2x"这幅图中。有后续的
                                        #数据要写入的话,可以删掉原文件,或者新建文件夹存起来
writer.close()

然后在terminal中输入:

tensorboard --logdir=logs --port=6007 #logs为存放事件的文件夹,并且指定端口

进入到:http://localhost:6007即可,结果如下:
pytorch学习(Dataset和Tensorboard)(二)_第3张图片

你可能感兴趣的:(pytorch,深度学习,机器学习)