之前用到了tf.slice,当时会用,现在又忘了,故重温一下,直接看示例
import tensorflow as tf
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
sess = tf.Session()
data = tf.slice(t, [1,0,0], [1,1,3])
print(sess.run(data))
data = tf.slice(t, [1, 0, 0], [2, 2, 3])
print(sess.run(data))
tf.slice(inputs, begin, size, name)包含几个参数,一般只需要设置前三个即可。
inputs是你要切割的tensor,begin是切割的起始位置,size是要切出来的大小。
对应上面的示例,我的理解是对于三维矩阵,第一维表示batch,第二位表示行,第三维表示列(index都是从0开始)。
具体点t是一个三维矩阵,begin是[1,0,0],也就是在从第2个batch中第1行第1列的位置开始切割,切割的大小是多少呢?再看size是[1,1,3],三个数字分别表示一整个batch、一行、三列,由于起始位置已经固定,所以切割出来的应该是第2个batch中第一行中的三列,正是[3,3,3]。