import torch
import torch.nn.functional as F
from torch import nn


class PlainDecoder(nn.Module):
    """Decoder head: dropout, 1x1 classifier conv, bilinear resize.

    Reads ``cfg.num_classes``, ``cfg.img_height`` and ``cfg.img_width``.
    The classifier conv is built for 128 input channels.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        """Return per-class maps resized to (cfg.img_height, cfg.img_width)."""
        x = self.dropout(x)
        x = self.conv8(x)
        return F.interpolate(
            x,
            size=[self.cfg.img_height, self.cfg.img_width],
            mode='bilinear',
            align_corners=False,
        )


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` from in_planes to out_planes."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                     bias=False)


class non_bottleneck_1d(nn.Module):
    """Residual block made of two factorized (3x1 then 1x3) conv pairs.

    Args:
        chann: number of input and output channels (all convs keep it).
        dropprob: probability handed to ``nn.Dropout2d``; dropout is skipped
            entirely in ``forward`` when it is 0.
        dilated: dilation (and matching padding) of the second conv pair.
    """

    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        # First pair: undilated 3x1 followed by 1x3.
        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)
        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        # Second pair: dilated along the kernel's long axis; padding scales
        # with the dilation so the spatial size is preserved.
        self.conv3x1_2 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0),
            bias=True, dilation=(dilated, 1))
        self.conv1x3_2 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated),
            bias=True, dilation=(1, dilated))
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        """Apply both conv pairs and add the identity shortcut before ReLU."""
        output = F.relu(self.conv3x1_1(input))
        output = F.relu(self.bn1(self.conv1x3_1(output)))

        output = F.relu(self.conv3x1_2(output))
        output = self.bn2(self.conv1x3_2(output))

        if self.dropout.p != 0:
            output = self.dropout(output)

        # + input = identity (residual connection)
        return F.relu(output + input)


class UpsamplerBlock(nn.Module):
    """Upsample with two parallel branches and sum them.

    Main branch: stride-2 transposed conv -> BN -> ReLU -> two
    ``non_bottleneck_1d`` blocks. Shortcut branch: 1x1 conv -> BN -> ReLU ->
    bilinear resize to ``(up_height, up_width)``.

    Args:
        ninput: input channel count.
        noutput: output channel count of both branches.
        up_width: target output width.
        up_height: target output height.
    """

    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        # With kernel 3, stride 2, padding 1, output_padding 1 the transposed
        # conv produces exactly twice the input height and width.
        self.conv = nn.ConvTranspose2d(
            ninput, noutput, 3, stride=2, padding=1, output_padding=1,
            bias=True)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # interpolate
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        """Return the sum of both branches at size (up_height, up_width)."""
        target_size = [self.up_height, self.up_width]

        out = F.relu(self.bn(self.conv(input)))
        for follow in self.follows:
            out = follow(out)

        shortcut = self.interpolate_conv(input)
        shortcut = F.relu(self.interpolate_bn(shortcut))
        shortcut = F.interpolate(shortcut, size=target_size,
                                 mode='bilinear', align_corners=False)

        # The transposed conv always yields 2x the input size, which differs
        # from the configured target when the input is not exactly half of
        # it (e.g. image sizes that are not divisible by the backbone
        # stride). Resize the main branch in that case so the sum below does
        # not fail on mismatched shapes; matching sizes take the same path
        # as before.
        if list(out.shape[-2:]) != target_size:
            out = F.interpolate(out, size=target_size,
                                mode='bilinear', align_corners=False)

        return out + shortcut


class BUSD(nn.Module):
    """Three stacked ``UpsamplerBlock`` stages followed by a 1x1 classifier.

    Stage targets are 1/4, 1/2 and full ``cfg.img_height`` x
    ``cfg.img_width``; the classifier maps 32 channels to
    ``cfg.num_classes``.
    """

    def __init__(self, cfg):
        super().__init__()
        img_height = int(cfg.img_height)
        img_width = int(cfg.img_width)
        num_classes = cfg.num_classes

        # (ninput, noutput, divisor of the image size targeted by the stage).
        # The second stage takes 128 channels because forward() concatenates
        # a skip tensor onto the first stage's 64-channel output.
        stage_specs = ((128, 64, 4), (128, 64, 2), (64, 32, 1))

        self.layers = nn.ModuleList()
        for ninput, noutput, divisor in stage_specs:
            self.layers.append(UpsamplerBlock(
                ninput=ninput, noutput=noutput,
                up_height=img_height // divisor,
                up_width=img_width // divisor))

        self.output_conv = conv1x1(32, num_classes)

    def forward(self, input):
        """Decode ``input[1]``, fusing ``input[0]`` after the first stage.

        Args:
            input: indexable pair. ``input[1]`` is fed to the first stage
                (128 channels expected by its convs). ``input[0]`` is
                concatenated on the channel axis with the first stage's
                output, so it needs 64 channels and the same batch/spatial
                size as that output.

        Returns:
            Output of the final 1x1 conv (``cfg.num_classes`` channels).
        """
        x, output = input[0], input[1]

        for i, layer in enumerate(self.layers):
            output = layer(output)
            if i == 0:
                # Channel-wise fusion with the skip tensor.
                output = torch.cat((x, output), dim=1)

        return self.output_conv(output)