MNIST进阶之next_batch()

对于训练数据分batch十分有必要,但是MNIST官网给出的程序采用了封装函数,本文写了一个next_batch()适用于自己制作的数据集。

import numpy as np


class DataSet(object):

    def __init__(self, images, labels, num_examples):
        self._images = images
        self._labels = labels
        self._epochs_completed = 0  # 完成遍历轮数
        self._index_in_epochs = 0   # 调用next_batch()函数后记住上一次位置
        self._num_examples = num_examples  # 训练样本数

    def next_batch(self, batch_size, fake_data=False, shuffle=True):
        start = self._index_in_epochs

        if self._epochs_completed == 0 and start == 0 and shuffle:
            index0 = np.arange(self._num_examples)
            print(index0)
            np.random.shuffle(index0)
            print(index0)
            self._images = np.array(self._images)[index0]
            self._labels = np.array(self._labels)[index0]
            print(self._images)
            print(self._labels)
            print("-----------------")

        if start + batch_size > self._num_examples:
            self._epochs_completed += 1
            rest_num_examples = self._num_examples - start
            images_rest_part = self._images[start:self._num_examples]
            labels_rest_part = self._labels[start:self._num_examples]
            if shuffle:
                index = np.arange(self._num_examples)
                np.random.shuffle(index)
                self._images = self._images[index]
                self._labels = self._labels[index]
            start = 0
            self._index_in_epochs = batch_size - rest_num_examples
            end = self._index_in_epochs
            images_new_part = self._images[start:end]
            labels_new_part = self._labels[start:end]
            return np.concatenate((images_rest_part, images_new_part), axis=0), np.concatenate(
                (labels_rest_part, labels_new_part), axis=0)

        else:
            self._index_in_epochs += batch_size
            end = self._index_in_epochs
            return self._images[start:end], self._labels[start:end]


if __name__ == '__main__':
    input = ['a', 'b', '1', '2', '*', '3', 'c', '&', '#']
    output = ["Letter", "Letter", "Number", "Number", "Symbol", "Number", "Letter", "Symbol", "Symbol"]
    ds = DataSet(input, output, 9)
    for i in range(3):
        image_batch, label_batch = ds.next_batch(4)
        print(image_batch)
        print(label_batch)

现在来看结果:

MNIST进阶之next_batch()_第1张图片

 

你可能感兴趣的:(DeepLearning)