PyTorch学习之:torch.gather是什么?

torch.gather的定义:

torch.gather 是 PyTorch 中的一个张量操作函数,其作用是根据指定的维度dim)和索引张量index),从输入张量(input)中收集元素,生成一个与索引张量形状相同的输出张量。总体来说,就是维度dim和索引张量index决定一个收集数的规则,然后,基于这个规则从输入张量中获取需要的元素。

核心部分:

1.输入张量input):

  • 任意形状的张量。

2.索引张量index):

  • 形状必须与输入张量在除 dim 外的其他维度上一致。

  • 索引值必须在输入张量 dim 维度的有效范围内(即 0 到 size(dim)-1)。

3.输出张量output):

  • 形状与索引张量相同。

  • 每个元素的值由以下规则确定:

output[i][j][k] = input[i][index[i][j][k]][k]  # 当 dim=1 时

举例详解:

示例 1:二维张量,dim=1

import torch

input = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)

output = torch.gather(input, dim=1, index=index)
print(output)

输出

tensor([[1, 1],
        [4, 3]])

 解释

输入是一个2x2的矩阵,因为dim是1,所以我们参考下面的公式:

output[i][j] = input[i][index[i][j]]  # 当 dim=1 时

对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。

对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][0](即为1)。

对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][1](即为4)。

对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[1][0](即为3)。

所以,最后的结果为:

tensor([[1, 1],
        [4, 3]])

 示例 2:二维张量,dim=0

import torch

input = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]], dtype=torch.long)

output = torch.gather(input, dim=0, index=index)
print(output)

输出

tensor([[1, 2],
        [3, 2]])

 解释

输入是一个2x2的矩阵,因为dim是0,所以我们参考下面的公式:

output[i][j] = input[index[i][j]][j]  # 当 dim=0 时

对于输出的第0行第0列(i = 0, j = 0),index对应的位置为0(因为index[0][0]为0),所以,对应的输出等于input[0][0](即为1)。

对于输出的第0行第1列(i = 0, j = 1),index对应的位置为0(因为index[0][1]为0),所以,对应的输出等于input[0][1](即为2)。

对于输出的第1行第0列(i = 1, j = 0),index对应的位置为1(因为index[1][0]为1),所以,对应的输出等于input[1][0](即为3)。

对于输出的第1行第1列(i = 1, j = 1),index对应的位置为0(因为index[1][1]为0),所以,对应的输出等于input[0][1](即为2)。

所以,最后的结果为:

tensor([[1, 2],
        [3, 2]])

你可能感兴趣的:(AI/AGI,NLP/LLMs,Computer,knowledge,pytorch,学习,人工智能,python)