torch.gather
是 PyTorch 中的一个张量操作函数,其作用是根据指定的维度(dim
)和索引张量(index
),从输入张量(input
)中收集元素,生成一个与索引张量形状相同的输出张量。总体来说,就是维度dim和索引张量index决定一个收集数的规则,然后,基于这个规则从输入张量中获取需要的元素。
input
):index
):形状必须与输入张量在除 dim
外的其他维度上一致。
索引值必须在输入张量 dim
维度的有效范围内(即 0
到 size(dim)-1
)。
output
):形状与索引张量相同。
每个元素的值由以下规则确定:
output[i][j][k] = input[i][index[i][j][k]][k] # 当 dim=1 时
dim=1
import torch
input = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)
output = torch.gather(input, dim=1, index=index)
print(output)
tensor([[1, 1],
[4, 3]])
输入是一个2x2的矩阵,因为dim是1,所以我们参考下面的公式:
output[i][j] = input[i][index[i][j]] # 当 dim=1 时
对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。
对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][0](即为1)。
对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][1](即为4)。
对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[1][0](即为3)。
所以,最后的结果为:
tensor([[1, 1],
[4, 3]])
dim=0
import torch
input = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)
output = torch.gather(input, dim=0, index=index)
print(output)
tensor([[1, 2],
[3, 2]])
输入是一个2x2的矩阵,因为dim是0,所以我们参考下面的公式:
output[i][j] = input[index[i][j]][j] # 当 dim=0 时
对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。
对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][1](即为2)。
对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][0](即为3)。
对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[0][1](即为2)。
所以,最后的结果为:
tensor([[1, 2],
[3, 2]])