torch.einsum
是 PyTorch 中非常强大的函数,用于 灵活定义张量的乘法、求和、转置、点乘、外积等各种线性代数操作。它源自 Einstein Summation(爱因斯坦求和约定),是一种更紧凑、可读性更强的多维操作方式。
torch.einsum(equation, *operands)
equation
: 字符串,定义操作规则。operands
: 一个或多个 Tensor,参与运算的张量。在 equation
中:
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([0.1, 0.2, 0.3])
res = torch.einsum('i,i->', a, b)
print(res) # 输出:1.4
解释:i,i->
表示对所有 i
求乘积再求和。
A = torch.randn(2, 3)
B = torch.randn(3, 4)
C = torch.einsum('ik,kj->ij', A, B)
print(C.shape) # torch.Size([2, 4])
相当于 torch.matmul(A, B)
A = torch.randn(3, 4)
x = torch.randn(4)
y = torch.einsum('ij,j->i', A, x)
print(y.shape) # torch.Size([3])
A = torch.randn(5, 2, 3) # batch_size=5
B = torch.randn(5, 3, 4)
C = torch.einsum('bij,bjk->bik', A, B)
print(C.shape) # torch.Size([5, 2, 4])
A = torch.randn(3, 4)
B = torch.einsum('ij->ji', A)
print(B.shape) # torch.Size([4, 3])
a = torch.tensor([1.0, 2.0])
b = torch.tensor([10.0, 20.0, 30.0])
res = torch.einsum('i,j->ij', a, b)
print(res)
# 输出形状为 (2, 3)
x = torch.randn(32, 128) # batch_size=32, dim=128
norm_sq = torch.einsum('bi,bi->b', x, x)
q = torch.randn(32, 8, 64) # query: batch, head, dim
k = torch.randn(32, 8, 64)
attn_scores = torch.einsum('bhd,bhd->bh', q, k)
图邻接矩阵传播(Graph Adjacency Propagation)通常用于图神经网络(GNN)中,表示将节点特征通过邻接结构进行消息传递(message passing)或特征聚合(feature aggregation)的过程。
我们将使用 torch.einsum
实现这个过程。
设:
import torch
# 假设图有 4 个节点,每个节点有 3 维特征
X = torch.tensor([[1.0, 0.5, 2.0],
[0.3, 1.2, 0.7],
[0.8, 0.1, 1.1],
[0.0, 0.3, 0.4]]) # (4, 3)
# 邻接矩阵 A(可为稀疏或归一化矩阵)
A = torch.tensor([[1, 1, 0, 0],
[1, 1, 1, 0],
[0, 1, 1, 1],
[0, 0, 1, 1]], dtype=torch.float32) # (4, 4)
# 使用 torch.einsum 进行图传播 AX
# 'ij,jk->ik':A(i,j) * X(j,k) → 输出 (i,k)
X_agg = torch.einsum('ij,jk->ik', A, X)
print("聚合后的特征:")
print(X_agg)
若有可学习的线性层 W ∈ R F × F ′ W \in \mathbb{R}^{F \times F'} W∈RF×F′,传播过程变为:
A X W AXW AXW
代码示例:
# W 是可学习的线性变换
W = torch.nn.Linear(in_features=3, out_features=2, bias=False)
# XW: 节点特征线性变换
X_transformed = W(X) # (4, 2)
# 再传播
X_out = torch.einsum('ij,jk->ik', A, X_transformed)
print("传播后的新特征维度:", X_out.shape)
设:
X
: (B, N, F) — batch_size 个图,每图 N 节点 F 维特征A
: (B, N, N) — batch 的邻接矩阵# B 个图,每图 4 节点,每个节点 3 维特征
X = torch.randn(8, 4, 3) # (B, N, F)
A = torch.eye(4).repeat(8, 1, 1) # (B, N, N)
# 批量邻接传播
X_agg = torch.einsum('bij,bjf->bif', A, X) # 输出 (B, N, F)
操作 | einsum 表达 |
说明 |
---|---|---|
图传播(AX) | 'ij,jk->ik' |
基础图邻接传播 |
批量传播 | 'bij,bjf->bif' |
批次图传播 |
权重传播 | AX @ W or 'ij,jk->ik' then Linear |
加入特征变换 |
einsum
?permute
+ view
+ matmul
;einsum_path
可用于优化路径选择:torch.einsum_path('bij,bjk->bik', A, B, optimize='optimal')
功能 | 示例公式 |
---|---|
向量点积 | 'i,i->' |
矩阵乘法 | 'ik,kj->ij' |
外积 | 'i,j->ij' |
转置 | 'ij->ji' |
batch matmul | 'bij,bjk->bik' |
L2 norm square | 'bi,bi->b' |