PyTorch implementation of CBAM (Convolutional Block Attention Module)

CBAM refines a feature map in two sequential steps: a channel attention module reweights the channels, then a spatial attention module reweights the spatial locations. Both are implemented below.

import torch.nn as nn
import torch


class ChannelAttention(nn.Module):
    """Channel attention: squeeze the spatial dimensions with both average
    and max pooling, push each descriptor through a shared two-layer
    bottleneck, and gate the input channels with the summed sigmoid."""

    def __init__(self, in_channels):
        super(ChannelAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)

        # Shared MLP implemented as 1x1 convolutions with a 16x channel reduction
        self.fc1 = nn.Conv2d(in_channels, in_channels // 16, kernel_size=1, padding=0, stride=1, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(in_channels // 16, in_channels, kernel_size=1, padding=0, stride=1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        original = x
        # The same fc1 -> relu -> fc2 stack is shared by both pooled descriptors
        avg_out = self.fc2(self.relu(self.fc1(self.avgpool(x))))
        max_out = self.fc2(self.relu(self.fc1(self.maxpool(x))))
        out = avg_out + max_out
        out = self.sigmoid(out)
        out = original * out

        return out

class SpatialAttention(nn.Module):
    """Spatial attention: pool along the channel dimension with both mean
    and max, stack the two maps, and learn a spatial gate with one conv."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # Two input channels (avg map + max map), one output attention map
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        original = x
        avg_out = torch.mean(x, dim=1, keepdim=True)    # (N, 1, H, W)
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # (N, 1, H, W)
        out = torch.cat([avg_out, max_out], dim=1)      # (N, 2, H, W)
        out = self.sigmoid(self.conv(out))
        return original * out