TensorFlow常用乘法函数总结:tf.multiply()、*、tf.matmul()、@、tf.scalar_mul()、tf.tensordot()、tf.einsum()

前两篇博客分别总结了numpy和Pytorch中常用的乘法函数:

numpy常用乘法函数总结:np.dot()、np.multiply()、*、np.matmul()、@、np.prod()、np.outer()-CSDN博客

  • 主要是 np.dot()、np.multiply()、*、np.matmul()、@ 五种,其中 np.matmul() 和 @ 完全等价,np.multiply() 和 * 在输入数据类型为 np.array 时也完全等价

Pytorch常用乘法函数总结:torch.mul()、*、torch.mm()、torch.bmm()、torch.mv()、torch.dot()、@、torch.matmul()-CSDN博客

  • torch.mul() 和 * 等价,为element-wise乘,可广播;
  • torch.mm() 二维矩阵乘法,torch.bmm() 三维批量矩阵乘法,均不可广播;
  • torch.mv(mat, vec) 为矩阵向量乘法,不可广播;
  • torch.dot() 仅支持两个一维向量点积,返回一个标量数字;
  • @ 矩阵乘法,等价于 torch.dot() + torch.mv() + torch.mm();
  • torch.matmul() 矩阵乘法,与 @ 类似,但它不止一维二维,可扩高维,可广播

本文总结TensorFlow常用的乘法函数,代码示例以TensorFlow2.x为例,其中 tf.tensordot() 这个函数的用法比较复杂;tf.einsum() 可以理解为一个通用的函数模板,在numpy和Pytorch中也有类似的函数和用法


常用

tf.multiply() 或 *【元素对位相乘】

tf.multiply(x, y) 等价于 x*y,为矩阵的element-wise乘法,要求两个矩阵的shape一致,或其中一个维度为1(扩展为另一个矩阵对应位置的维度大小) 

import tensorflow as tf

X = tf.constant([[1, 2, 3], [4, 5 ,6]], dtype=tf.float32)
Y = tf.constant([[1, 1, 1], [2, 2 ,2]], dtype=tf.float32)
Z = tf.multiply(X, Y)       # 乘法操作,对应位置元素相乘

out:
tf.Tensor(
[[ 1.  2.  3.]
 [ 8. 10. 12.]], shape=(2, 3), dtype=float32)

tf.matmul() 或 @【矩阵乘法】

tf.matmul(x, y) 等价于 x@y,为矩阵乘法,参与运算的是最后两维形成的矩阵

  • 矩阵-向量乘法不能用这个函数,会报错,应该用 tf.linalg.matvec() 
  • 两个向量的点积不能用这个函数,会报错,可以用以下两种方式:
    • tf.tensordot(a, b, axes=1) 或 tf.tensordot(a, b, axes=[0, 0])
    • tf.reduce_sum(tf.multiply(a, b))
import tensorflow as tf

X = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
Y = tf.constant([[1, 2], [1, 2], [1, 2]], dtype=tf.float32)
Z = tf.matmul(X, Y)         # 矩阵乘法操作

out:
tf.Tensor(
[[ 6. 12.]
 [15. 30.]], shape=(2, 2), dtype=float32)

不常用

tf.scalar_mul()【参数之一为标量】

标量和张量相乘(标量乘标量或向量或矩阵)

import tensorflow as tf

x = tf.constant(2, dtype=tf.float32)

Y1 = tf.constant(3, dtype=tf.float32)
Z1 = tf.scalar_mul(x, Y1)         # 标量×标量

Y2 = tf.constant([1, 2, 3], dtype=tf.float32)
Z2 = tf.scalar_mul(x, Y2)         # 标量×向量

Y3 = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
Z3 = tf.scalar_mul(x, Y3)         # 标量×矩阵

out:
tf.Tensor(6.0, shape=(), dtype=float32)
tf.Tensor([2. 4. 6.], shape=(3,), dtype=float32)
tf.Tensor(
[[ 2.  4.  6.]
 [ 8. 10. 12.]], shape=(2, 3), dtype=float32)

tf.tensordot()

参考了博客 tensorflow和num

你可能感兴趣的:(#,TensorFlow,tensorflow,pytorch,numpy)