torch tensor shape 从 3*,3 到 N,3,3。使用 repeat 而不要 expand

下面的代码会导致 报错同一个内存被多个索引使用。需要改成 repeat

batch_rotation_matrix = single_rotation_matrix.unsqueeze(0).expand(N, -1, -1)

修改之后,成功运行:

batch_rotation_matrix = single_rotation_matrix.repeat(N, 1, 1)

你可能感兴趣的:(python,深度学习,机器学习)