数据集自动划分训练集和验证集的代码

针对类似于猫狗大战的二分类数据集,拿到的数据集目录结构为:

    datasets/
    ├── label_00/
    └── label_11/

label_00 文件夹下是正样本,label_11 文件夹下是负样本。如果想把数据随机打乱后按 7:3 的比例划分出训练集和验证集,代码如下:

# -*- coding: utf-8 -*-
import shutil
import keras
import numpy as np
from keras.models import Sequential
from keras.initializers import he_normal
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import ModelCheckpoint
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D, AveragePooling2D
import keras.backend as K
K.set_image_data_format('channels_last')
K.set_learning_phase(1)
import os
import matplotlib.pyplot as plt
#加载keras模块
from keras import optimizers
from keras.optimizers import SGD, Adam, RMSprop
import matplotlib.pyplot as plt
def get_files(file_dir):
    '''
    Collect the image paths of both classes and return them in random order.

    Every entry of ``file_dir/label_00`` is given label 0 and every entry of
    ``file_dir/label_11`` is given label 1. Paths are built with '/' as the
    separator, e.g. ``file_dir/label_00/name``.

    Args:
        file_dir: directory that contains the ``label_00`` and ``label_11``
            sub-folders.
    Returns:
        (image_list, label_list): two lists of equal length. image_list holds
        the file paths as plain ``str``; label_list holds the matching ``int``
        labels. Both are shuffled with the same permutation, so index i of one
        list always corresponds to index i of the other.
    Raises:
        OSError: if one of the two sub-folders cannot be listed.
    '''
    samples = []
    for sub_dir, label in (('label_00', 0), ('label_11', 1)):
        class_dir = file_dir + '/' + sub_dir
        # os.listdir order is unspecified; sorting first makes the shuffle
        # below reproducible when the caller seeds np.random.
        for name in sorted(os.listdir(class_dir)):
            samples.append((class_dir + '/' + name, label))

    # Shuffle by index so paths and labels stay paired without packing them
    # into one NumPy string array and parsing the labels back to int.
    order = np.random.permutation(len(samples))
    image_list = [samples[i][0] for i in order]
    label_list = [samples[i][1] for i in order]

    return image_list, label_list

#返回两个list 分别为图片文件名及其标签  顺序已被打乱



train_dir = 'datasets'      # folder holding the original label_00 / label_11 data
image_list, label_list = get_files(train_dir)

new_dir = 'datasets_new'    # root folder that receives the train/test split

# get_files already shuffled the samples, so the first 70% form a random
# training split and the remaining 30% are held out.
split_at = int(len(image_list) * 0.7)
train_image_list = list(image_list[:split_at])
train_label_list = list(label_list[:split_at])
test_image_list = list(image_list[split_at:])
test_label_list = list(label_list[split_at:])

# Output sub-folder for each label value produced by get_files.
label_dirs = {0: 'label_00', 1: 'label_11'}

for split_name, split_images, split_labels in (
        ('train_data', train_image_list, train_label_list),
        ('test_data', test_image_list, test_label_list)):
    # Create the destination folders up front; shutil.copy does not create
    # missing parent directories.
    for sub_dir in label_dirs.values():
        target_dir = os.path.join(new_dir, split_name, sub_dir)
        if not os.path.isdir(target_dir):
            os.makedirs(target_dir)

    for filename, label in zip(split_images, split_labels):
        if split_name == 'train_data':
            # Progress output is only printed for the training split.
            print(filename)
        # Pick the class folder from the label returned by get_files instead
        # of guessing it from a substring of the path, and take the file name
        # with basename so train_dir may contain any number of components.
        shutil.copy(filename,
                    os.path.join(new_dir, split_name, label_dirs[label],
                                 os.path.basename(filename)))

最后得到的文件目录为:

    datasets_new/
    ├── train_data/
    │   ├── label_00/
    │   └── label_11/
    └── test_data/
        ├── label_00/
        └── label_11/

你可能感兴趣的:(深度学习,图像)