'''Convolutional Block Attention Module (CBAM).'''

import torch
import torch.nn as nn


class Channel_Attention(nn.Module):
    '''Channel Attention in CBAM.'''

    def __init__(self, channel_in, reduction_ratio=16, pool_types=['avg', 'max']):
        '''Param init and architecture building.'''

        super(Channel_Attention, self).__init__()
        self.pool_types = pool_types

        # Shared MLP applied to both the average-pooled and max-pooled descriptors.
        self.shared_mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=channel_in, out_features=channel_in//reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=channel_in//reduction_ratio, out_features=channel_in)
        )

    def forward(self, x):
        '''Forward Propagation.'''

        channel_attentions = []

        for pool_type in self.pool_types:
            if pool_type == 'avg':
                # Global average pooling over the spatial dimensions.
                pool_init = nn.AvgPool2d(kernel_size=(x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                avg_pool = pool_init(x)
                channel_attentions.append(self.shared_mlp(avg_pool))
            elif pool_type == 'max':
                # Global max pooling over the spatial dimensions.
                pool_init = nn.MaxPool2d(kernel_size=(x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                max_pool = pool_init(x)
                channel_attentions.append(self.shared_mlp(max_pool))

        # Element-wise sum of the attention vectors from every pooling branch.
        pooling_sums = torch.stack(channel_attentions, dim=0).sum(dim=0)
        scaled = nn.Sigmoid()(pooling_sums).unsqueeze(2).unsqueeze(3).expand_as(x)

        return x * scaled  # Return the element-wise multiplication between the input and the attention map.


class ChannelPool(nn.Module):
    '''Merge all the channels in a feature map into two separate channels,
    where the first channel is produced by taking the max value across all
    channels and the second one by taking the mean across all channels.
    '''

    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class Spatial_Attention(nn.Module):
    '''Spatial Attention in CBAM.'''

    def __init__(self, kernel_size=7):
        '''Spatial Attention Architecture.'''

        super(Spatial_Attention, self).__init__()

        self.compress = ChannelPool()
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, stride=1,
                      dilation=1, padding=(kernel_size-1)//2, bias=False),
            nn.BatchNorm2d(num_features=1, eps=1e-5, momentum=0.01, affine=True)
        )

    def forward(self, x):
        '''Forward Propagation.'''

        x_compress = self.compress(x)
        x_output = self.spatial_attention(x_compress)
        scaled = nn.Sigmoid()(x_output)

        return x * scaled


class CBAM(nn.Module):
    '''CBAM architecture.'''

    def __init__(self, channel_in, reduction_ratio=16, pool_types=['avg', 'max'], spatial=True):
        '''Param init and arch build.'''

        super(CBAM, self).__init__()
        self.spatial = spatial

        self.channel_attention = Channel_Attention(channel_in=channel_in, reduction_ratio=reduction_ratio, pool_types=pool_types)

        if self.spatial:
            self.spatial_attention = Spatial_Attention(kernel_size=7)

    def forward(self, x):
        '''Forward Propagation.'''

        # Channel attention is applied first, followed (optionally) by spatial attention.
        x_out = self.channel_attention(x)

        if self.spatial:
            x_out = self.spatial_attention(x_out)

        return x_out
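

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): runs CBAM on a random
# feature map to confirm that the attention blocks preserve the input shape.
# The tensor sizes below are illustrative assumptions, not values from the source.
if __name__ == '__main__':
    feature_map = torch.randn(4, 64, 32, 32)  # (batch, channels, height, width)
    cbam = CBAM(channel_in=64, reduction_ratio=16, pool_types=['avg', 'max'], spatial=True)
    refined = cbam(feature_map)
    print(refined.shape)  # torch.Size([4, 64, 32, 32]) -- same shape as the input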