在使用 PyTorch 进行深度学习时,张量形状与广播机制常常是让初学者感到困惑的地方。我们需要时常面对多维张量,并在批量、通道、空间位置等多个维度之间做运算。
如果能熟练掌握各种维度变换操作——包括 unsqueeze
、expand
、view
/reshape
、transpose
/permute
等,可以帮助我们灵活地操纵张量,写出高效而简洁的矩阵化(vectorized)代码。本文将重点聚焦于 expand
以及与之密切相关的“维度扩展”技巧和其底层原理。
expand
?许多同学初次接触 PyTorch 的广播(broadcasting)时,往往只知道 +
、*
等运算符会自动广播。确实,在能自动广播的场合,可能不一定要手动去调用 expand
。但实际开发中,有些场景我们想精确控制哪一个维度被复制、如何被复制,就会手动用到 expand
:
(B, K, H*W)
的张量,在“anchor 样本 b”与“被遍历样本 b2”两大维度上进行运算,那么就会创建 (B, B, K, H*W)
的大张量。如果仅靠自动广播,有时不能一步到位。[batch_size, seq_len, embed_dim]
的张量复制到 [batch_size, num_heads, seq_len, embed_dim]
。可以先 unsqueeze
一个维度,再 expand
。[batch_size]
做逐元素操作,就会先 unsqueeze(0)
,然后 expand(batch_size)
,以与全 batch 配合。在上面的逻辑中,如果只依赖自动广播,可能会猜错要广播的维度或顺序,也可能出现形状不匹配的报错。此时 expand
能显式地告诉 PyTorch:“我想在某维度上复制这个张量”。
expand
的原理只允许在 size=1 的维度上扩展
expand
成更大的尺寸会报错。unsqueeze
,让它变成 size=1,然后再 expand
。并不真正创建新内存
expand
返回的是一个“视图(view)”,它只修改内部的 stride(步长)信息,多个位置共享同一块存储。repeat
会产生真的数据复制。配合广播(broadcasting)机制
expand
等价于手动触发“在一个维度做复制”。[B, K, H*W]
的张量扩充到 [B, B, K, H*W]
时,就可以先在前面 unsqueeze(0)
变成 [1, B, K, H*W]
,然后 expand(B, -1, -1, -1)
就搞定了。(B, K, H*W)
的张量扩展为 (B, B, K, H*W)
让我们用一个经典的“对比学习”或“实例判别损失”场景里的例子,说明如何借助 expand
来完成批量级别的并行计算。假设有:
logits.shape == (B, K, H*W) pos_mask_2d.shape == (B, K, H*W)
其中:
B
= batch sizeK
= 类别数H*W
= 空间分辨率的展平结果1. 把 (B, K, H*W)
的 logits
变成 (B, B, K, H*W)
pos_logits_4d = logits.unsqueeze(0) # => shape: (1, B, K, H*W)
pos_logits_4d = pos_logits_4d.expand(B, -1, -1, -1)
# => shape: (B, B, K, H*W)
解释:
unsqueeze(0)
在最前面插入一个维度,使其从 (B, K, H*W)
变成 (1, B, K, H*W)
;expand(B, -1, -1, -1)
把第 0 个维度从 1 扩张到 B,其它维度用 -1
表示“保持原大小不变”。pos_logits_4d[b, b2, k, hw]
就能同时区分“anchor 样本 b”与“被遍历的样本 b2”。2. 把 (B, K, H*W)
的 pos_mask_2d
变成 (B, B, K, H*W)
pos_mask_4d = pos_mask_2d.unsqueeze(1) # => shape: (B, 1, K, H*W)
pos_mask_4d = pos_mask_4d.expand(-1, B, -1, -1)
# => shape: (B, B, K, H*W)
解释:
(B, K, H*W)
-> (B, 1, K, H*W)
;(B, B, K, H*W)
;pos_mask_2d[b, k, hw] = True
),则在 pos_mask_4d[b, b2, k, hw]
整个 b2 维度上都保留 True;否则全是 False。这样一来,“anchor 样本 b 的正类”就可以在所有 b2 的 logits 里做筛选。和原先用双层循环 (for b in range(B): for hw in range(H*W): ...
) 实现的逻辑对照起来,是一模一样的功能,但是通过一次 expand
+masked_fill
就搞定了所有 b 和 hw 的遍历。
unsqueeze + expand
非常常见
维度含义要保持一致
pos_logits_4d
用第 0 维做 b,pos_mask_4d
用第 1 维做 b,就没法对齐。底层不复制数据,但要注意共享内存
expand
是视图,原地操作(in-place)时要谨慎。如果对原张量做了 .fill_()
或 .add_()
等,可能引发错误或改变“被 expand 出来的视图”的值。对于非常大维度的场合,要警惕显存
expand
本身不复制数据,但后续操作(比如把 (B, B, K, H*W)
做某种运算)可能会创建新的临时张量。B, K, H*W
都很大,会导致显存飙升。要留意是否可以做更精简的计算方式。expand
与其他维度变形操作
view
/ reshape
:主要改变形状,元素总数不变。squeeze
/ unsqueeze
:去除或插入维度=1 的轴。permute
/ transpose
:交换或重排维度顺序。repeat
:与 expand
功能相似,但它会真正复制数据,可能在某些特定场景用起来直观,但通常更耗内存。