'''Convolutional Block Attention Module (CBAM).

Implements the channel and spatial attention blocks described in
Woo et al., "CBAM: Convolutional Block Attention Module" (ECCV 2018).
'''

import torch
import torch.nn as nn


class Channel_Attention(nn.Module):
    '''Channel attention sub-module of CBAM.

    Squeezes the spatial dimensions with average- and max-pooling, pushes both
    descriptors through a shared MLP, and uses the summed result to re-weight
    the input channels.
    '''

    def __init__(self, channel_in, reduction_ratio=16, pool_types=('avg', 'max')):
        '''Parameter init and architecture building.'''

        super(Channel_Attention, self).__init__()
        self.pool_types = pool_types

        # Shared MLP: squeeze to channel_in // reduction_ratio, then restore.
        self.shared_mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=channel_in, out_features=channel_in // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=channel_in // reduction_ratio, out_features=channel_in)
        )

    def forward(self, x):
        '''Forward propagation.'''

        channel_attentions = []

        for pool_type in self.pool_types:
            if pool_type == 'avg':
                # Global average pooling over the full spatial extent.
                pool_init = nn.AvgPool2d(kernel_size=(x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                avg_pool = pool_init(x)
                channel_attentions.append(self.shared_mlp(avg_pool))
            elif pool_type == 'max':
                # Global max pooling over the full spatial extent.
                pool_init = nn.MaxPool2d(kernel_size=(x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                max_pool = pool_init(x)
                channel_attentions.append(self.shared_mlp(max_pool))

        # Sum the per-pooling attention logits, squash to (0, 1), and broadcast
        # back to the input's spatial shape.
        pooling_sums = torch.stack(channel_attentions, dim=0).sum(dim=0)
        scaled = nn.Sigmoid()(pooling_sums).unsqueeze(2).unsqueeze(3).expand_as(x)

        return x * scaled  # element-wise multiplication between the input and the attention map

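
# For reference, the channel attention above follows the CBAM formulation
#     M_c(F) = sigmoid( MLP(AvgPool(F)) + MLP(MaxPool(F)) ),
# i.e. one learned scalar weight per channel, with the MLP shared across pool types.
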
class ChannelPool(nn.Module):
    '''Collapse a feature map's channels into two channels: the first holds the
    per-position maximum over all channels, the second the per-position mean.
    '''

    def forward(self, x):
        # Concatenate channel-wise max and mean along dim=1.
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

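
# Shape example (illustrative values): a (8, 64, 32, 32) feature map passed
# through ChannelPool becomes (8, 2, 32, 32), channel-wise max stacked on
# channel-wise mean, ready for the 2-channel convolution in Spatial_Attention.
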
class Spatial_Attention(nn.Module):
    '''Spatial attention sub-module of CBAM.

    Compresses the channel dimension with ChannelPool, then learns a single
    2-D attention map with a large-kernel convolution.
    '''

    def __init__(self, kernel_size=7):
        '''Parameter init and architecture building.'''

        super(Spatial_Attention, self).__init__()

        self.compress = ChannelPool()
        self.spatial_attention = nn.Sequential(
            # 'same'-style padding keeps the spatial resolution unchanged.
            nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, stride=1, dilation=1, padding=(kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(num_features=1, eps=1e-5, momentum=0.01, affine=True)
        )

    def forward(self, x):
        '''Forward propagation.'''

        x_compress = self.compress(x)
        x_output = self.spatial_attention(x_compress)
        scaled = nn.Sigmoid()(x_output)
        return x * scaled  # broadcast the (N, 1, H, W) attention map over all channels

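
# For reference, this matches the CBAM spatial attention
#     M_s(F) = sigmoid( f_7x7([AvgPool_c(F); MaxPool_c(F)]) ),
# where f_7x7 is the convolution above and the pooling runs over the channel axis.
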
class CBAM(nn.Module):
    '''Full CBAM block: channel attention followed (optionally) by spatial attention.
    '''

    def __init__(self, channel_in, reduction_ratio=16, pool_types=('avg', 'max'), spatial=True):
        '''Parameter init and architecture building.'''

        super(CBAM, self).__init__()
        self.spatial = spatial

        self.channel_attention = Channel_Attention(channel_in=channel_in, reduction_ratio=reduction_ratio, pool_types=pool_types)

        if self.spatial:
            self.spatial_attention = Spatial_Attention(kernel_size=7)

    def forward(self, x):
        '''Forward propagation.'''

        # Refine channels first, then (optionally) refine spatial locations.
        x_out = self.channel_attention(x)
        if self.spatial:
            x_out = self.spatial_attention(x_out)

        return x_out

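
# Minimal usage sketch. The batch size, channel count, and spatial resolution
# below are arbitrary example values, not anything CBAM requires.
if __name__ == '__main__':
    feature_map = torch.randn(8, 64, 32, 32)        # dummy (N, C, H, W) features
    cbam = CBAM(channel_in=64, reduction_ratio=16)  # channel_in must match C
    refined = cbam(feature_map)
    print(refined.shape)                            # torch.Size([8, 64, 32, 32])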