torch.mm
是 PyTorch 中用于 二维矩阵乘法(matrix-matrix multiplication) 的函数,等价于数学中的 A × B
矩阵乘积。
torch.mm(input, mat2) → Tensor
执行的是两个 2D Tensor(矩阵)的标准矩阵乘法。
input
: 第一个二维张量,形状为 (n × m)
mat2
: 第二个二维张量,形状为 (m × p)
(n × p)
的张量条件 | 说明 |
---|---|
仅支持 2D 张量 | 一维或三维以上使用 torch.matmul 或 @ 操作符 |
维度要匹配 | 即 input.shape[1] == mat2.shape[0] |
不支持广播 | 两个矩阵维度不匹配会直接报错 |
结果是普通矩阵乘积 | 不是逐元素乘法(Hadamard),即不是 * 或 torch.mul() |
import torch
A = torch.tensor([[1., 2.], [3., 4.]]) # 2x2
B = torch.tensor([[5., 6.], [7., 8.]]) # 2x2
C = torch.mm(A, B)
print(C)
输出:
tensor([[19., 22.],
[43., 50.]])
计算步骤:
C[0][0] = 1*5 + 2*7 = 19
C[0][1] = 1*6 + 2*8 = 22
...
A = torch.rand(2, 3)
B = torch.rand(4, 2)
C = torch.mm(A, B) # ❌ 会报错
报错:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 4x2)
@
或 matmul
)A = torch.rand(3, 4)
B = torch.rand(4, 5)
C1 = torch.mm(A, B)
C2 = A @ B # 推荐用法
C3 = torch.matmul(A, B) # 推荐用法
函数名 | 支持维度 | 运算类型 | 支持广播 |
---|---|---|---|
torch.mm |
仅限二维 | 矩阵乘法 | ❌ 不支持 |
torch.matmul |
1D, 2D, ND | 自动判断点乘 / 矩阵乘 | ✅ 支持 |
torch.bmm |
批量二维乘法 | 3D Tensor batch × batch | ❌ 不支持 |
torch.mul |
任意维度 | 元素乘(Hadamard) | ✅ 支持 |
* 运算符 |
任意维度 | 元素乘 | ✅ 支持 |
@ 运算符 |
ND(推荐用) | 矩阵乘法(和 matmul 一样) | ✅ |
output = torch.mm(W, x)
x' = torch.mm(R, x) + t
mm
?使用场景 | 用什么 |
---|---|
仅二维矩阵乘法 | torch.mm |
高维或支持广播乘法 | torch.matmul / @ |
批量矩阵乘法 (如 batch_size×3×3) | torch.bmm |
元素乘 | torch.mul or * |
在 PyTorch 中,torch.bmm
是 批量矩阵乘法(batch matrix multiplication) 的操作,专用于处理三维张量(batch of matrices)。它的主要作用是对一组矩阵成对进行乘法,效率远高于手动循环计算。
torch.bmm
语法torch.bmm(input, mat2, *, out=None) → Tensor
Tensor
,形状为 (B, N, M)
Tensor
,形状为 (B, M, P)
(B, N, P)
这表示对 B
对 N×M
和 M×P
的矩阵进行成对相乘。
import torch
# 定义两个 batch 矩阵
A = torch.randn(4, 2, 3) # shape: (B=4, N=2, M=3)
B = torch.randn(4, 3, 5) # shape: (B=4, M=3, P=5)
# 批量矩阵乘法
C = torch.bmm(A, B) # shape: (4, 2, 5)
print(C.shape) # 输出: torch.Size([4, 2, 5])
# 慢速手动方式
C_manual = torch.stack([A[i] @ B[i] for i in range(A.size(0))])
# 等效于 bmm
C_bmm = torch.bmm(A, B)
print(torch.allclose(C_manual, C_bmm)) # True
RuntimeError: batch1 must be a 3D tensor
你可以通过 .unsqueeze()
手动调整维度:
a = torch.randn(2, 3)
b = torch.randn(3, 4)
# 升维
a_batch = a.unsqueeze(0) # (1, 2, 3)
b_batch = b.unsqueeze(0) # (1, 3, 4)
c = torch.bmm(a_batch, b_batch) # (1, 2, 4)
(B, N, M)
× (B, M, P)
→ (B, N, P)
M
不一致会报错:RuntimeError: Expected size for the second dimension of batch2 tensor to match the first dimension of batch1 tensor
bmm
不支持广播(broadcasting).expand()
:A = torch.randn(1, 2, 3) # 单个矩阵
B = torch.randn(4, 3, 5) # 4 个矩阵
# 扩展 A 以进行 batch 乘法
A_expand = A.expand(4, -1, -1)
C = torch.bmm(A_expand, B) # (4, 2, 5)
# 假设有 B 个旋转矩阵和点坐标
R = torch.randn(B, 3, 3) # 旋转矩阵
points = torch.randn(B, 3, N) # 点云
# 先转置点坐标为 (B, N, 3)
points_T = points.transpose(1, 2) # (B, N, 3)
# 用 bmm 做点变换:每组点乘旋转
transformed = torch.bmm(points_T, R.transpose(1, 2)) # (B, N, 3)
特性 | torch.bmm |
---|---|
操作对象 | 三维张量(batch of matrices) |
核心规则 | (B, N, M) x (B, M, P) = (B, N, P) |
是否支持广播 | ❌ 不支持,需要手动 .expand() |
与 matmul 区别 |
matmul 支持更多广播,bmm 更高效用于纯批量矩阵乘法 |
应用场景 | 批量线性变换、点云配准、神经网络前向传播等 |