brain-tumor_image_classific.../cbam.py

109 lines
3.5 KiB
Python

'''Convolutional Block Attention Module (CBAM)
'''
import torch
import torch.nn as nn
from torch.nn.modules import pooling
from torch.nn.modules.flatten import Flatten
class Channel_Attention(nn.Module):
    '''Channel Attention in CBAM.

    Each channel of the input is squeezed to a single value with global
    pooling (average and/or max). Every pooled descriptor goes through one
    shared two-layer MLP, the MLP outputs are summed, and the input channels
    are rescaled by the sigmoid of that sum.
    '''

    # Global pooling ops keyed by the names accepted in ``pool_types``.
    # ``output_size=1`` pools over the whole spatial extent, which matches
    # what a per-call AvgPool2d/MaxPool2d with kernel == stride == (H, W)
    # computes.
    _POOLS = {
        'avg': lambda x: nn.functional.adaptive_avg_pool2d(x, 1),
        'max': lambda x: nn.functional.adaptive_max_pool2d(x, 1),
    }

    def __init__(self, channel_in, reduction_ratio=16, pool_types=('avg', 'max')):
        '''Param init and architecture building.

        Args:
            channel_in: number of input channels C; the shared MLP maps C -> C.
            reduction_ratio: bottleneck factor. The hidden layer has
                ``channel_in // reduction_ratio`` units, floored at 1 so that
                small channel counts do not produce a zero-width layer.
            pool_types: iterable of pooling names. 'avg' and 'max' are
                supported; any other names are ignored.

        Raises:
            ValueError: if ``pool_types`` contains no supported name, because
                forward would then have nothing to combine.
        '''
        super(Channel_Attention, self).__init__()
        # Copy so instances never share (or alias) the caller's/default container.
        self.pool_types = list(pool_types)
        if not any(p in self._POOLS for p in self.pool_types):
            raise ValueError(
                f"pool_types must include 'avg' and/or 'max', got {pool_types!r}"
            )
        hidden = max(1, channel_in // reduction_ratio)
        # Layer order/indices are unchanged so existing state_dict keys still load.
        self.shared_mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=channel_in, out_features=hidden),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden, out_features=channel_in),
        )

    def forward(self, x):
        '''Forward Propagation.

        Args:
            x: 4-D feature map; assumed (N, C, H, W) with C == channel_in.

        Returns:
            ``x`` multiplied element-wise by per-channel weights in (0, 1).
        '''
        # Pooled descriptors are (N, C, 1, 1); the MLP's Flatten turns them into (N, C).
        attention = sum(
            self.shared_mlp(self._POOLS[p](x))
            for p in self.pool_types
            if p in self._POOLS
        )
        # (N, C) -> (N, C, 1, 1); broadcasting spreads the weights over H and W.
        scale = torch.sigmoid(attention)[:, :, None, None]
        return x * scale
class ChannelPool(nn.Module):
    '''Collapse the channel axis of a feature map into two channels.

    Output channel 0 holds, at every position, the maximum over all input
    channels. Output channel 1 holds the mean over all input channels.
    '''

    def forward(self, x):
        '''Return the channel-wise max and mean maps concatenated along dim 1.'''
        channel_max = torch.max(x, dim=1, keepdim=True).values
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        return torch.cat([channel_max, channel_mean], dim=1)
class Spatial_Attention(nn.Module):
    '''Spatial Attention in CBAM.

    The input is compressed to a 2-channel (max, mean) map with ChannelPool.
    That map is convolved down to one channel and batch-normalised. Every
    spatial position of the input is then rescaled by the sigmoid of the
    result.
    '''

    def __init__(self, kernel_size=7):
        '''Spatial Attention Architecture.

        Args:
            kernel_size: side length of the square convolution kernel. It must
                be a positive odd integer, because the padding of
                ``(kernel_size - 1) // 2`` only preserves H and W for odd
                sizes. An even size shrinks the attention map by one and makes
                the final ``x * scaled`` fail to broadcast.

        Raises:
            ValueError: if ``kernel_size`` is not a positive odd integer.
        '''
        super(Spatial_Attention, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f'kernel_size must be a positive odd integer, got {kernel_size!r}')
        self.compress = ChannelPool()
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, stride=1,
                      dilation=1, padding=(kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(num_features=1, eps=1e-5, momentum=0.01, affine=True),
        )

    def forward(self, x):
        '''Forward Propagation.

        Returns ``x`` multiplied by a one-channel attention map with values in
        (0, 1), broadcast across all channels of ``x``.
        '''
        attention = self.spatial_attention(self.compress(x))
        return x * torch.sigmoid(attention)
class CBAM(nn.Module):
    '''CBAM architecture.

    Channel attention is applied first, then (optionally) spatial attention.
    Each stage returns its input multiplied element-wise by an attention map.
    '''

    def __init__(self, channel_in, reduction_ratio=16, pool_types=('avg', 'max'),
                 spatial=True, kernel_size=7):
        '''Param init and arch build.

        Args:
            channel_in: number of input channels, forwarded to Channel_Attention.
            reduction_ratio: MLP bottleneck factor for channel attention.
            pool_types: pooling names for channel attention. The default is a
                tuple, so it cannot be mutated and shared across instances.
            spatial: when False, the spatial attention stage is skipped and not
                even constructed.
            kernel_size: convolution kernel size for spatial attention. It was
                previously hard-coded to 7, which remains the default.
        '''
        super(CBAM, self).__init__()
        self.spatial = spatial
        self.channel_attention = Channel_Attention(
            channel_in=channel_in,
            reduction_ratio=reduction_ratio,
            pool_types=pool_types,
        )
        if self.spatial:
            self.spatial_attention = Spatial_Attention(kernel_size=kernel_size)

    def forward(self, x):
        '''Forward Propagation: channel attention, then optional spatial attention.'''
        x_out = self.channel_attention(x)
        if self.spatial:
            x_out = self.spatial_attention(x_out)
        return x_out