Official documentation: https://tensorflow.google.cn/versions/r1.8/api_docs/python/tf/train/batch
tf.train.batch(
    tensors,
    batch_size,
    num_threads=1,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None
)
Purpose: builds batches of examples from a list or dictionary of tensors.
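To make the list-or-dictionary behavior concrete, here is a plain-Python model (no TensorFlow) of how single examples are grouped: the output keeps the structure of the input, and each component becomes a list of batch_size values. The function name batch_examples is hypothetical, and dropping the incomplete tail only mimics allow_smaller_final_batch=False; it is not TensorFlow's queue implementation.

```python
def batch_examples(examples, batch_size, allow_smaller_final_batch=False):
    """Group single examples into batches, keeping the list/dict structure.

    Conceptual model only (hypothetical helper): each example is a list
    (one entry per component) or a dict (one key per component), mirroring
    the list/dict `tensors` argument of tf.train.batch.
    """
    batches = []
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        if len(chunk) < batch_size and not allow_smaller_final_batch:
            break  # the real op signals end-of-data instead of returning this tail
        if isinstance(chunk[0], dict):
            # dict in -> dict out, one batched list per key
            batches.append({k: [ex[k] for ex in chunk] for k in chunk[0]})
        else:
            # list in -> list out, one batched list per component
            batches.append([[ex[i] for ex in chunk] for i in range(len(chunk[0]))])
    return batches


examples = [{"x": i, "label": i % 2} for i in range(7)]
print(batch_examples(examples, 3))
# [{'x': [0, 1, 2], 'label': [0, 1, 0]}, {'x': [3, 4, 5], 'label': [1, 0, 1]}]
print(len(batch_examples(examples, 3, allow_smaller_final_batch=True)))  # 3
```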
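Parameters:
tensors: the list or dictionary of tensors to enqueue.
batch_size: the number of examples in each dequeued batch.
num_threads: the number of threads enqueuing tensors; batching is nondeterministic when num_threads > 1.
capacity: the maximum number of elements in the queue.
enqueue_many: whether each tensor in tensors is a single example (False) or a batch of examples whose first dimension indexes the examples (True).
shapes: (optional) the shape of each example; defaults to the inferred shapes of tensors.
dynamic_pad: allow variable dimensions in the input shapes; those dimensions are padded on dequeue so that all tensors in a batch have the same shape.
allow_smaller_final_batch: if True, the final batch may be smaller than batch_size when the queue is closed and too few elements remain; otherwise the leftover elements are discarded.
shared_name: (optional) if set, the queue is shared under this name across multiple sessions.
name: (optional) a name for the operations.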
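The enqueue_many and dynamic_pad parameters are easiest to see on toy data. The sketch below is a plain-Python analogy built on collections.deque; enqueue and dequeue_batch are hypothetical helpers, and padding with 0 assumes a numeric dtype.

```python
from collections import deque


def enqueue(queue, value, enqueue_many=False):
    """Model of enqueue_many: False -> `value` is ONE example;
    True -> the first dimension of `value` indexes separate examples."""
    if enqueue_many:
        queue.extend(value)
    else:
        queue.append(value)


def dequeue_batch(queue, batch_size, dynamic_pad=False):
    """Pop batch_size examples; with dynamic_pad, right-pad each
    variable-length example with zeros to the longest one in the batch."""
    batch = [queue.popleft() for _ in range(batch_size)]
    if dynamic_pad:
        width = max(len(ex) for ex in batch)
        batch = [list(ex) + [0] * (width - len(ex)) for ex in batch]
    return batch


q = deque()
enqueue(q, [1, 2, 3])                         # one example of length 3
enqueue(q, [[4], [5, 6]], enqueue_many=True)  # two examples, lengths 1 and 2
print(dequeue_batch(q, 3, dynamic_pad=True))  # [[1, 2, 3], [4, 0, 0], [5, 6, 0]]
```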
Example: fetch batch_size elements at a time from an array
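import numpy as np
import tensorflow as tf  # TensorFlow 1.x (queue-based input pipeline)

def next_batch():
    datasets = np.asarray(range(0, 20))
    # Produce one element per step, in order (shuffle=False), for a single pass.
    # num_epochs=1 creates a local epoch-counter variable.
    input_queue = tf.train.slice_input_producer([datasets], shuffle=False, num_epochs=1)
    # Group consecutive elements into batches of 5.
    data_batchs = tf.train.batch(input_queue, batch_size=5, num_threads=1,
                                 capacity=20, allow_smaller_final_batch=False)
    return data_batchs

if __name__ == "__main__":
    data_batchs = next_batch()
    sess = tf.Session()
    # Initialize the local epoch counter created by num_epochs
    # (tf.local_variables_initializer() replaces the deprecated tf.initialize_local_variables()).
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    # Start the QueueRunner threads that fill the input and batch queues.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            data = sess.run([data_batchs])
            print(data)
    except tf.errors.OutOfRangeError:
        # Raised once the single epoch is exhausted and no full batch remains.
        print("complete")
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()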
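Note: tf.train.batch is implemented with a queue, and its QueueRunner is added to the current graph's tf.GraphKeys.QUEUE_RUNNERS collection, which is why tf.train.start_queue_runners() must be called before fetching batches. The initialization requirement comes from num_epochs=1 in tf.train.slice_input_producer: it creates an epoch-counter variable in the LOCAL_VARIABLES collection. You must therefore run tf.local_variables_initializer() (formerly tf.initialize_local_variables()); running only tf.global_variables_initializer() leaves that counter uninitialized and fails with: Attempting to use uninitialized value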
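The queue mechanism described in the note can be approximated with Python's standard library: a producer thread plays the QueueRunner, a bounded queue.Queue plays the batch queue (maxsize = capacity), and a sentinel marks the queue as closed so the consumer receives an OutOfRange-style signal. This is an analogy under those assumptions, not TensorFlow's implementation; OutOfRange, producer, dequeue_many and run are hypothetical names. With the example's settings it yields four batches of five, in order.

```python
import queue
import threading

_CLOSED = object()  # sentinel standing in for "queue closed"


class OutOfRange(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def producer(q, data, num_epochs):
    # Plays the QueueRunner thread: enqueue every element num_epochs times,
    # in order (like shuffle=False), then mark the queue as closed.
    for _ in range(num_epochs):
        for x in data:
            q.put(x)  # blocks while the queue already holds `capacity` items
    q.put(_CLOSED)


def dequeue_many(q, batch_size):
    batch = []
    while len(batch) < batch_size:
        item = q.get()
        if item is _CLOSED:
            q.put(_CLOSED)    # keep the queue "closed" for any later caller
            raise OutOfRange  # partial batch dropped, like allow_smaller_final_batch=False
        batch.append(item)
    return batch


def run(data, batch_size, capacity, num_epochs=1):
    q = queue.Queue(maxsize=capacity)
    t = threading.Thread(target=producer, args=(q, data, num_epochs), daemon=True)
    t.start()
    batches = []
    try:
        while True:
            batches.append(dequeue_many(q, batch_size))
    except OutOfRange:
        pass  # the TF loop prints "complete" at this point
    finally:
        t.join()
    return batches


print(run(list(range(20)), batch_size=5, capacity=20))
# [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]
```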