在PyTorch中,view
函数是一个极为重要的张量操作函数,其主要功能是对张量的形状进行重塑,与此同时会尽力维持张量中元素的总数不变。
view
函数的主要作用是改变张量的维度和大小,不过要保证重塑前后张量的元素总数相同。其语法格式如下:
tensor.view(*args)
这里的*args
代表的是新的形状,它既可以是一个元组,也可以是多个用逗号分隔的整数。
x = torch.randn(2, 3, 4) # 此时x的形状为[2, 3, 4]
y = x.view(2, 12) # y的形状变为[2, 12]
x = torch.randn(6) # x的形状是[6]
y = x.view(2, 3) # y的形状变为[2, 3]
当在形状参数里使用-1
时,PyTorch会依据张量元素的总数以及其他维度的大小,自动推算出-1
所对应的维度值。
x = torch.randn(2, 3, 4) # x的形状为[2, 3, 4]
y = x.view(2, -1) # y的形状是[2, 12]
z = x.view(-1, 3, 2) # z的形状为[4, 3, 2]
view
函数要求输入的张量必须是内存连续的。要是张量在内存中不连续,就需要先调用contiguous()
函数。
x = torch.randn(2, 3)
y = x.t() # 对x进行转置操作,此时y在内存中不再连续
z = y.contiguous().view(3, 2) # 先调用contiguous(),再使用view
view
函数:必须在张量内存连续的情况下才能使用,不过它能保证返回的是原张量的视图,这意味着不会进行数据拷贝,从而可以提升内存使用效率。reshape
函数:无论张量内存是否连续都能使用,它可能会返回原张量的视图,也可能会进行数据拷贝。view
函数返回的是原张量的视图,而非新的张量。这就表明对视图进行修改时,原张量也会随之改变。
x = torch.tensor([1, 2, 3, 4])
y = x.view(2, 2)
y[0, 0] = 100
print(x[0]) # 输出结果为100
x = torch.randn(2, 2, 2, 2) # x的形状为[2, 2, 2, 2]
y = x.view(2, 8) # y的形状变为[2, 8]
z = x.view(-1) # z是一个一维张量,形状为[16]
view
函数时,新形状的元素总数必须和原张量的元素总数相等。contiguous()
函数。view
函数在大多数情况下比reshape
函数的性能要好,但在使用时需要更加谨慎。在深度学习模型里,view
函数经常被用于调整输入或输出的形状,像在全连接层和卷积层之间进行过渡时就会用到。
# 模拟一个CNN输出
x = torch.randn(16, 3, 28, 28) # 批量大小为16,3个通道,28×28的图像
y = x.view(16, -1) # 将特征图展平为一维向量,形状变为[16, 2352]
在PyTorch中,view
和reshape
都用于改变张量的形状,但它们在功能、内存管理和使用场景上存在关键差异。以下是详细比较:
特性 | view | reshape |
---|---|---|
内存连续性要求 | 必须内存连续 (contiguous ) |
无要求,自动处理非连续张量 |
返回类型 | 始终返回原张量的视图(不复制数据) | 可能返回视图或副本(取决于是否需要复制数据) |
异常处理 | 若张量不连续,抛出 RuntimeError |
自动调用 contiguous() 避免错误 |
x = torch.randn(2, 3)
y = x.t() # 转置操作使 y 不连续
z = y.view(6) # 报错:RuntimeError
z = y.contiguous().view(6) # 正确:先转为连续张量
x = torch.randn(2, 3)
y = x.t()
z = y.reshape(6) # 等价于 y.contiguous().view(6),自动处理连续性
x = torch.tensor([1, 2, 3, 4])
y = x.view(2, 2)
y[0, 0] = 100
print(x[0]) # 输出: 100(原张量被修改)
x = torch.tensor([1, 2, 3, 4])
y = x.t().reshape(2, 2) # 因转置导致非连续,reshape 可能拷贝数据
y[0, 0] = 100
print(x[0]) # 输出: 1(原张量未被修改,reshape 创建了副本)
建议:若需确保性能且张量连续,优先使用 view
;若不确定连续性或追求代码简洁,使用 reshape
。
场景 | 推荐函数 | 原因 |
---|---|---|
已知张量连续且需高效操作 | view |
避免不必要的拷贝 |
处理可能非连续的张量 | reshape |
自动处理连续性,避免错误 |
深度学习模型中的固定操作 | view |
如 CNN 到全连接层的张量展平 |
快速原型开发或代码简化 | reshape |
减少 contiguous() 调用 |
两者均支持 -1
作为占位符,PyTorch 会自动计算该维度的大小:
x = torch.randn(2, 3, 4)
y = x.view(2, -1) # y.shape: [2, 12]
z = x.reshape(-1, 3, 2) # z.shape: [4, 3, 2]
功能 | view | reshape |
---|---|---|
内存连续性要求 | 必须连续 | 无要求 |
是否保证零拷贝 | 是 | 否(可能拷贝) |
处理非连续张量 | 需手动调用 contiguous() |
自动处理 |
代码简洁性 | 较低(需关注连续性) | 较高 |
性能 | 高 | 中等(可能有拷贝开销) |
view
:当张量连续性已知且需确保性能时(如模型前向传播)。reshape
:在数据处理或不确定张量连续性时,避免 RuntimeError
。view
报错,检查张量是否连续(使用 tensor.is_contiguous()
)。