针对tensorflow2.0
(1)tensor的拼接(tf.concat)
In [5]: a = tf.ones([4,28,28,3])
In [6]: b = tf.ones([2,28,28,3])
In [7]: c = tf.concat([a,b],axis=0)
In [8]: c.shape
Out[8]: TensorShape([6, 28, 28, 3])
具体用法为:c = tf.concat([a,b],axis=0),axis代表轴的序号。需要注意的是:a、b中,需要拼接的维度可以不等,但是其他的维度必须相等。
(2)tensor的堆叠(tf.stack)
堆叠的意思是将两个数据合并并且扩展出一个轴来。
In [9]: a = tf.ones([28,28,3])
In [10]: b = tf.ones([28,28,3])
In [11]: c = tf.stack([a,b],axis=0)
In [12]: c.shape
Out[12]: TensorShape([2, 28, 28, 3])
具体做法是:c = tf.stack([a,b],axis=0),axis代表扩展出的轴所代表的序号。需要注意的是:a、b中,所有的维度必须相等。
(1)tf.unstack(与tf.stack效果相反,指定轴全都打散为1)
In [12]: c.shape
Out[12]: TensorShape([2, 28, 28, 3])
In [13]: aa,bb = tf.unstack(c,axis=0)
In [14]: aa.shape,bb.shape
Out[14]: (TensorShape([28, 28, 3]), TensorShape([28, 28, 3]))
In [15]: res = tf.unstack(c,axis=3)
In [16]: res[0].shape,res[1].shape,res[2].shape
Out[16]: (TensorShape([2, 28, 28]), TensorShape([2, 28, 28]), TensorShape([2, 28, 28]))
具体做法是:res = tf.unstack(c,axis=3),axis代表打散的轴所代表的序号,res存储输出的结果。
(2)tf.split
上面的tf.unstack函数是将指定轴全都打散为1,而tf.split更加灵活,它可以打散成所需要的各个维度。
In [12]: c.shape
Out[12]: TensorShape([2, 28, 28, 3])
In [19]: res = tf.split(c,axis=2,num_or_size_splits=[14,5,9])
In [20]: res[0].shape,res[1].shape,res[2].shape
Out[20]:
(TensorShape([2, 28, 14, 3]),
TensorShape([2, 28, 5, 3]),
TensorShape([2, 28, 9, 3]))
具体做法是:res = tf.split(c,axis=2,num_or_size_splits=[14,5,9]),axis代表打散的轴所代表的序号;num_or_size_splits代表打散分为几个tensor,以及每个tensor占几维;res存储输出的结果。