【北上广深杭大厂AI算法面试题】深度学习篇…Squeeze Excitation(SE)网络结构详解,附代码。(二)
欢迎铁子们点赞、关注、收藏!
祝大家逢考必过!逢投必中!上岸上岸上岸!upupup
大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文。详细信息可关注VX “
学术会议小灵通
”或参考学术信息专栏:https://blog.csdn.net/2401_89898861/article/details/145551342
SE 模块通常被嵌入到ResNet、MobileNet等网络中。例如,在 ResNet 残差块后面加上 SE 模块:
class SEResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, reduction=16):
super(SEResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
# SE 模块
self.se = SqueezeExcitation(out_channels, reduction)
# 残差连接
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
# 插入 SE 模块
out = self.se(out)
# 残差连接
out += self.shortcut(x)
return self.relu(out)
# 测试 SEResNet 残差块
x = torch.randn(1, 64, 32, 32) # 假设输入通道 64
se_res_block = SEResidualBlock(64, 64)
y = se_res_block(x)
print(y.shape) # 预计输出: torch.Size([1, 64, 32, 32])
SE 模块被广泛应用于:
在 ImageNet 上,SE-ResNet 在几乎不增加计算量的情况下,提升了 ResNet 的分类准确率。
在大厂 AI 面试中,SE 模块常涉及:
如果你准备大厂面试,一定要掌握 SE 模块的实现和优化思路!