tensorflow神经网络批量训练读取数据yield_Model模板

参数说明: 

传入训练数据和标签,设定batch_size数值大小,shuffle表示是否打乱顺序

def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
    assert len(inputs) == len(targets)
    if shuffle: 
        indices = np.arange(len(inputs)) 
        np.random.shuffle(indices)  
    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):  
        if shuffle:
            excerpt = indices[start_idx:start_idx + batch_size] 
        else:
            excerpt = slice(start_idx, start_idx + batch_size)  
        yield inputs[excerpt], targets[excerpt]  # 提取相应的样本数据和标签数据

yield解释:

 

  • yield 是一个类似 return 的关键字,只是这个函数返回的是个生成器
  • 当你调用这个函数的时候,函数内部的代码并不立马执行 ,这个函数只是返回一个生成器对象
  • 当你使用for进行迭代的时候,函数中的代码才会执行
  • 例如:
    >>> def createGenerator() :
    ...    mylist = range(3)
    ...    for i in mylist :
    ...        yield i*i
    ...
    >>> mygenerator = createGenerator() # create a generator
    >>> print(mygenerator) # mygenerator is an object!
    
    >>> for i in mygenerator:
    ...     print(i)
    0
    1
    4

     

你可能感兴趣的:(Python)