在我的torchvision库里介绍的博文(https://www.cnblogs.com/yjphhw/p/9773333.html)里说了对pytorch的dataset的定义方式。
本文相当于实现一个自定义的数据集,而这正是我们在做自己工程所需要的,我们总是用自己的数据嘛。
继承 from torch.utils.data import Dataset 类
然后实现 __len__(self) ,和 __getitem__(self,idx) 两个方法。以及数据增强也可以写入,数据增强想了想还是放到了Dataset里,
习惯上可能与常用的不同,但是觉得由于每种数据都有自己的增强方法所以,增强方法可以和数据集绑定到一起的。
接上一节我们通过切割,获取了2217个图像切片。
这就是我的FarmDataset
from torch.utils.data import Dataset, DataLoader
from PIL import Image,ImageEnhance
from osgeo import gdal
from torchvision import transforms
import glob
import torch as tc
import numpy as np
class FarmDataset(Dataset):
def __init__(self,istrain=True,isaug=True):
self.istrain=istrain
self.trainxformat='./data/train/data1500/*.bmp'
self.trainyformat='./data/train/label1500/*.bmp'
self.testxformat='./data/test/*.png'
self.fns=glob.glob(self.trainxformat) if istrain else glob.glob(self.testxformat)
self.length=len(self.fns)
self.transforms=transforms
self.isaug=isaug
def __len__(self):
#total length is 2217
return self.length
def __getitem__(self,idx):
if self.istrain:
imgxname=self.fns[idx]
sampleimg = Image.open(imgxname)
imgyname=imgxname.replace('data1500','label1500')
targetimg = Image.open(imgyname).convert('L')
#sampleimg.save('original.bmp')
#data augmentation
if self.isaug:
sampleimg,targetimg=self.imgtrans(sampleimg,targetimg)
#check the result of dataautmentation
#sampleimg.save('sampletmp.bmp')
#targetimg.save('targettmp.bmp')
sampleimg=transforms.ToTensor()(sampleimg)
#targetimg=transforms.ToTensor()(targetimg).squeeze(0).long()
targetimg=np.array(targetimg)
targetimg=tc.from_numpy(targetimg).long() #to tensor
#print(sampleimg.shape,targetimg.shape)
return sampleimg,targetimg
else:
return gdal.Open(self.fns[idx])
def imgtrans(self,x,y,outsize=1024):
'''input is a PIL image
image dataaugumentation
return also aPIL image。
'''
#rotate should consider y
degree=np.random.randint(360)
x=x.rotate(degree,resample=Image.NEAREST,fillcolor=0)
y=y.rotate(degree,resample=Image.NEAREST,fillcolor=0) #here should be carefull, in case of label damage
#random do the input image augmentation
if np.random.random()>0.5:
#sharpness
factor=0.5+np.random.random()
enhancer=ImageEnhance.Sharpness(x)
x=enhancer.enhance(factor)
if np.random.random()>0.5:
#color augument
factor=0.5+np.random.random()
enhancer=ImageEnhance.Color(x)
x=enhancer.enhance(factor)
if np.random.random()>0.5:
#contrast augument
factor=0.5+np.random.random()
enhancer=ImageEnhance.Contrast(x)
x=enhancer.enhance(factor)
if np.random.random()>0.5:
#brightness
factor=0.5+np.random.random()
enhancer=ImageEnhance.Brightness(x)
x=enhancer.enhance(factor)
#img flip
transtypes=[Image.FLIP_LEFT_RIGHT,Image.FLIP_TOP_BOTTOM,
Image.ROTATE_90,Image.ROTATE_180,Image.ROTATE_270]
transtype=transtypes[np.random.randint(len(transtypes))]
x = x.transpose(transtype)
y = y.transpose(transtype)
#img resize between 0.8-1.2
w,h=x.size
factor=1+np.random.normal()/5
if factor>1.2: factor=1.2
if factor<0.8: factor=0.8
#print(factor,x.size)
x=x.resize((int(w*factor),int(h*factor)),Image.NEAREST)
y=y.resize((int(w*factor),int(h*factor)),Image.NEAREST)
#random crop
w,h=x.size
stx=np.random.randint(w-outsize)
sty=np.random.randint(h-outsize)
#print((stx,sty,outsize,outsize))
x=x.crop((stx,sty,stx+outsize,sty+outsize)) #stx,sty,width,height
y=y.crop((stx,sty,stx+outsize,sty+outsize))
#print(x.size,y.size)
return x,y #return outsized pil image
if __name__=='__main__':
d=FarmDataset(istrain=True)
x,y=d[2216]
print(x.shape)
print(y.shape)
输入的是个1500x1500的图像,输出的是增强后的1024x1024后的图像。
其实对于分割问题来看,以后这个就可以作为一个模板,修改修改就可以换到另一个数据集中。
放几张图片:
原始图像:
进行数据增强后可以得到的一系列:
经过check 发现没有的问题通过测试。