前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。
MF-PSN(Multi-Feature Pyramid Stereo Network)是一个基于金字塔特征的多特征立体匹配网络,它通过构建多层次的特征金字塔来处理不同尺度的立体匹配问题。该项目中的关键组件包括两个最大池化层用于特征融合,这种设计旨在聚合多尺度信息但可能丢失一些细节特征。
TransXNet是一种新型的Transformer-CNN混合架构,它通过引入跨特征交互模块和动态位置偏置来增强特征的表达能力。相比传统的最大池化操作,TransXNet能够更有效地融合多尺度特征,同时保留更多的空间细节信息。
在将TransXNet模块引入MF-PSN网络以替代原有的最大池化特征融合方法后,出现了显存爆炸(GPU内存溢出)的问题。具体表现为:
经过初步分析,可能的原因包括:
TransXNet的自注意力机制:Transformer结构中的自注意力计算会生成大规模的中间矩阵(特别是QKV矩阵),这些矩阵的尺寸与输入分辨率的平方成正比。
特征图尺寸不匹配:TransXNet的输入输出维度可能与原MF-PSN网络不兼容,导致特征图尺寸意外扩大。
激活值保留:PyTorch默认会保留所有中间激活值用于梯度计算,而TransXNet的复杂结构会产生大量中间结果。
梯度计算开销:混合架构中CNN和Transformer组件的交互可能导致梯度计算图异常复杂。
在深度学习训练过程中,GPU内存主要被以下部分占用:
为了精确诊断内存问题,我们可以使用以下工具:
# PyTorch内存分析
import torch
print(torch.cuda.memory_allocated()) # 当前已分配内存
print(torch.cuda.max_memory_allocated()) # 最大分配内存
# 使用memory_profiler进行逐行分析
from memory_profiler import profile
@profile
def train_batch():
# 训练代码
pass
针对本问题的优化策略可以分为四个层次:
原始的TransXNet模块可能包含过多的参数和复杂的操作,我们可以进行以下调整:
class LiteTransXNet(nn.Module):
def __init__(self, in_channels, reduction_ratio=4):
super().__init__()
# 减少头的数量
self.num_heads = max(2, in_channels // 32)
# 添加下采样减少特征图尺寸
self.downsample = nn.Conv2d(in_channels, in_channels//2, kernel_size=3, stride=2, padding=1)
# 简化FFN网络
self.ffn = nn.Sequential(
nn.Conv2d(in_channels//2, in_channels//reduction_ratio, 1),
nn.GELU(),
nn.Conv2d(in_channels//reduction_ratio, in_channels//2, 1)
)
# 恢复原始尺寸
self.upsample = nn.ConvTranspose2d(in_channels//2, in_channels, kernel_size=3, stride=2, padding=1)
def forward(self, x):
residual = x
x = self.downsample(x)
# 简化的自注意力计算
b, c, h, w = x.shape
qkv = x.reshape(b, c, -1).permute(0, 2, 1)
attn = (qkv @ qkv.transpose(-2, -1)) * (c ** -0.5)
attn = attn.softmax(dim=-1)
x = (attn @ qkv).permute(0, 2, 1).reshape(b, c, h, w)
x = x + self.ffn(x)
x = self.upsample(x)
return x + residual
高分辨率特征图是内存消耗的主要原因之一,我们可以:
class FeatureCompressor(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1)
self.pool = nn.AdaptiveAvgPool2d((None, None)) # 保持空间维度
def forward(self, x):
return self.pool(self.conv(x))
将标准卷积替换为更高效的变体:
def conv_block(in_c, out_c, kernel_size=3