pytorch小记(二十六):全面解读 PyTorch 的 `torch.matmul`

pytorch小记(二十六):全面解读 PyTorch 的 `torch.matmul`

  • PyTorch 中的 `torch.matmul` 详解与使用指南
    • 一、什么是 `torch.matmul`
    • 二、基本用法示例
      • 1. 向量点积(1-D × 1-D)
      • 2. 二维矩阵乘法(2-D × 2-D)
      • 3. 批量矩阵乘法(≥3-D)
      • 4. 向量与矩阵混合
    • 三、与 `mm`、`bmm` 的区别
    • 四、性能与数值稳定性
    • 五、典型应用场景
    • 六、注意事项
    • 七、总结


在深度学习与张量运算中,维度管理是一个常见且容易出错的环节。PyTorch 提供的 unsqueeze 操作,可以让你在张量的任意位置插入一个长度为 1 的新维度,实现批处理、通道对齐、广播扩展等需求。下面,我们从原理、API、示例到常见替代写法,带你深入理解 unsqueeze 的用法与场景。


PyTorch 中的 torch.matmul 详解与使用指南

在深度学习中,矩阵乘法是最常见的线性代数操作之一。PyTorch 为我们提供了多种矩阵乘法接口,其中最通用、灵活的就是 torch.matmul。本文将从概念、用法、性能及注意事项等角度,带你全面了解和掌握 torch.matmul


一、什么是 torch.matmul

torch.matmul 是 PyTorch 中用于矩阵(或张量)乘法的通用函数,它会根据输入张量的维度自动分配到不同的运算模式:

  1. 1-D × 1-D → 标量(点积)
  2. 2-D × 2-D → 矩阵乘法
  3. ≥3-D → 批量矩阵乘法(Batch Matrix Multiply),并支持广播
  4. 1-D 与 2-D 混合 → 向量-矩阵或矩阵-向量乘

相比于只支持二维的 torch.mm 或仅支持三维批量乘的 torch.bmmtorch.matmul 具有更高的通用性和灵活性。


二、基本用法示例

1. 向量点积(1-D × 1-D)

import torch

a = torch.tensor([1., 2., 3.])   # 形状 (3,)
b = torch.tensor([4., 5., 6.])   # 形状 (3,)
c = torch.matmul(a, b)           # 等同 torch.dot(a, b)
# 结果: 1*4 + 2*5 + 3*6 = 32.0
print(c)  # tensor(32.)

2. 二维矩阵乘法(2-D × 2-D)

A = torch.randn(2, 3)    # 2×3
B = torch.randn(3, 4)    # 3×4
C = torch.matmul(A, B)   # 结果 2×4
print(C.shape)           # torch.Size([2, 4])

3. 批量矩阵乘法(≥3-D)

A = torch.randn(10, 2, 3)   # (batch=10, 2, 3)
B = torch.randn(10, 3, 4)   # (batch=10, 3, 4)
C = torch.matmul(A, B)      # (10, 2, 4)
print(C.shape)              # torch.Size([10, 2, 4])

# 支持广播
A = torch.randn(5, 1, 2, 3)   # (5,1,2,3)
B = torch.randn(   4, 3, 4)   # (   4,3,4)
C = torch.matmul(A, B)        # 先广播到 (5,4,2,3) @ (5,4,3,4) → (5,4,2,4)
print(C.shape)                # torch.Size([5, 4, 2, 4])

4. 向量与矩阵混合

v = torch.randn(3)         # (3,)
M = torch.randn(3, 4)      # (3,4)
r1 = torch.matmul(v, M)    # 视作 (1×3) @ (3×4) → (4,)
r2 = torch.matmul(M, v)    # 视作 (3×4) @ (4×1) → (3,)
print(r1.shape, r2.shape)  # torch.Size([4]) torch.Size([3])

三、与 mmbmm 的区别

  • torch.mm(A, B)
    · 只能作用于二维矩阵,形状必须是 (m, p) × (p, n)
  • torch.bmm(A, B)
    · 只能作用于三维批量矩阵,形状 (b, m, p) × (b, p, n)
  • torch.matmul
    · 自动识别维度,从一维到高维都能处理,且支持广播策略。

如果你的场景仅仅是普通的二维乘法,也可以用 Python 语法糖 @

C = A @ B   # 等同于 torch.matmul(A, B) 或 torch.mm(A, B)

四、性能与数值稳定性

  • 底层实现torch.matmul 会根据输入维度,自动调用最优的 BLAS(CPU)或 cuBLAS(GPU)接口,性能与手动调用 mm / bmm 无差异。
  • 并行与向量化:在批量乘法时,充分利用硬件并行能力,速度远超 Python 层的循环。

Tip:如果你的张量非常稀疏,或拥有特定结构(比如对称、对角),可考虑使用稀疏矩阵乘法或专门的线性代数库,以获得更好性能。


五、典型应用场景

  1. Attention 机制:在 Transformer 及自注意力模块中,常见的 Q·Kᵀ、Attention 权重与 V 的乘法,都可用 torch.matmul 一行搞定。
  2. 图神经网络:批量节点特征与邻接矩阵的乘法,构成了图卷积的核心操作。
  3. RNN/GRU/LSTM:门控单元内部,各个状态与权重矩阵的乘法同样离不开 matmul
  4. 物理仿真与渲染:刚体坐标变换、相机投影矩阵与点云/图像数据的乘积,可直接调用 torch.matmul

六、注意事项

  • 确保 最后两个维度 的形状是可乘的,即 A.shape[-1] == B.shape[-2],否则会抛出维度不匹配的运行时错误。
  • 对于 非常大 的批量张量,请关注显存或内存占用,必要时可拆分成更小的子批次(mini-batch)。
  • 当你只做小型或一次性的矩阵乘法,用 CPU 速度往往足够;若需大规模并行计算,请务必使用 GPU 并行。

七、总结

  • torch.matmul 是一个“一键式”矩阵/张量乘法接口,支持从 1D 到 nD 的多种乘法规则。
  • 它能够自动分配到点积、普通矩阵乘或批量矩阵乘,并支持广播,极大简化了代码量。
  • 无论是构建 Transformer、RNN,还是图神经网络和物理仿真,都能依赖它完成高效的线性代数运算。

你可能感兴趣的:(pytorch小记(二十六):全面解读 PyTorch 的 `torch.matmul`)