PyTorch Tensor usage
Official documentation: https://pytorch.org/docs/stable/tensors.html?highlight=scatter_add#torch.Tensor.scatter_add_
Function signature: scatter_add_(dim, indexTensor, otherTensor) → output Tensor
Usage: selfTensor.scatter_add_(dim, indexTensor, otherTensor)
Requirements:
self, index and other should have the same number of dimensions.
index.size(d) <= other.size(d) for all dimensions d.
index.size(d) <= self.size(d) for all dimensions d != dim.
As with .gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive.
The documentation quoted here also states that the values in a row along the specified dimension dim must be unique.
Example code: final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
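To make the example call concrete, here is a minimal sketch. The shapes and values are assumptions chosen for illustration, not taken from the original code: vocab_dist_ is assumed to be (batch, vocab), while enc_batch_extend_vocab and attn_dist_ are assumed to be (batch, src_len), with the index tensor holding int64 vocabulary ids. Note that the index rows here contain repeated values, and in this sketch their contributions are summed.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, vocab_size, src_len = 2, 6, 4

# Assumed: a per-example distribution over the vocabulary (zeros here for clarity).
vocab_dist_ = torch.zeros(batch_size, vocab_size)

# Assumed: vocabulary id of each source token; scatter_add needs an int64 index.
enc_batch_extend_vocab = torch.tensor([[1, 3, 3, 5],
                                       [0, 0, 2, 4]])

# Assumed: attention weight assigned to each source token.
attn_dist_ = torch.tensor([[0.10, 0.20, 0.30, 0.40],
                           [0.25, 0.25, 0.25, 0.25]])

# Out-of-place variant: returns a new tensor and leaves vocab_dist_ untouched.
# Along dim=1, attn_dist_[b][t] is added at column enc_batch_extend_vocab[b][t].
final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)

print(final_dist)
```

In the first row, vocabulary id 3 appears twice, so it receives both attention weights (roughly 0.2 + 0.3). The trailing-underscore form, scatter_add_, would instead modify vocab_dist_ in place.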
This function adds every value of otherTensor into selfTensor, at the positions given by indexTensor. For a 3-D tensor:
self[index[i][j][k]][j][k] += other[i][j][k]  # if dim == 0
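The dim == 0 rule can be checked against an explicit loop. The following is a sketch only: the helper name scatter_add_dim0_reference and the tensor sizes are hypothetical, and the loop is meant to illustrate the indexing rule, not to be used in real code.

```python
import torch

def scatter_add_dim0_reference(self_tensor, index, other):
    """Loop form of the dim == 0 rule for 3-D tensors (illustration only)."""
    out = self_tensor.clone()
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            for k in range(index.size(2)):
                row = index[i, j, k].item()  # target position along dim 0
                out[row, j, k] += other[i, j, k]
    return out

torch.manual_seed(0)
self_tensor = torch.zeros(3, 2, 2)
other = torch.rand(4, 2, 2)
# Index values must lie in [0, self_tensor.size(0) - 1].
index = torch.randint(0, 3, (4, 2, 2))

expected = scatter_add_dim0_reference(self_tensor, index, other)
# In-place call on a copy, so self_tensor itself stays unchanged here.
result = self_tensor.clone().scatter_add_(0, index, other)

print(torch.allclose(result, expected))
```

For dim == 1 or dim == 2 the same pattern applies, with index[i][j][k] replacing the second or third subscript of self instead of the first.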