144 lines
4.7 KiB
Python
144 lines
4.7 KiB
Python
from torch import nn
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
class PlainDecoder(nn.Module):
    """Lightweight decoder head: dropout, 1x1 classifier, bilinear resize.

    Args:
        cfg: Configuration object. This class reads ``num_classes`` when it
            is constructed and ``img_height`` / ``img_width`` on every
            forward pass.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Channel-wise dropout applied to the incoming feature map.
        self.dropout = nn.Dropout2d(p=0.1)
        # Pointwise classifier; the incoming feature map must have 128 channels.
        self.conv8 = nn.Conv2d(128, cfg.num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-class maps resized to the configured image size.

        Args:
            x: 4-D feature tensor whose channel dimension is 128.

        Returns:
            Tensor with ``cfg.num_classes`` channels and spatial size
            ``(cfg.img_height, cfg.img_width)``.
        """
        logits = self.conv8(self.dropout(x))
        target_size = [self.cfg.img_height, self.cfg.img_width]
        return F.interpolate(
            logits,
            size=target_size,
            mode='bilinear',
            align_corners=False,
        )
|
|
|
|
|
|
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (defaults to 1).

    Returns:
        A new ``nn.Conv2d`` with ``kernel_size=1`` and ``bias=False``.
    """
    pointwise = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
|
|
|
|
|
|
class non_bottleneck_1d(nn.Module):
    """Residual block built from two factorised (3x1 then 1x3) conv pairs.

    The first pair is undilated; the second pair uses ``dilated`` as its
    dilation (and matching padding) along the axis each kernel spans, so the
    spatial size is preserved. Each pair is followed by batch norm, and the
    block input is added back before the final ReLU.

    Args:
        chann: Channel count; every convolution maps ``chann`` to ``chann``.
        dropprob: Probability for the ``nn.Dropout2d`` applied before the
            residual addition. Dropout is skipped entirely when it is 0.
        dilated: Dilation factor of the second convolution pair.
    """

    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        # First factorised pair (no dilation).
        self.conv3x1_1 = nn.Conv2d(
            chann, chann, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=True,
        )
        self.conv1x3_1 = nn.Conv2d(
            chann, chann, kernel_size=(1, 3), stride=1, padding=(0, 1), bias=True,
        )
        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        # Second factorised pair: padding grows with the dilation so the
        # output keeps the input's height and width.
        self.conv3x1_2 = nn.Conv2d(
            chann, chann, kernel_size=(3, 1), stride=1,
            padding=(1 * dilated, 0), bias=True, dilation=(dilated, 1),
        )
        self.conv1x3_2 = nn.Conv2d(
            chann, chann, kernel_size=(1, 3), stride=1,
            padding=(0, 1 * dilated), bias=True, dilation=(1, dilated),
        )
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        """Run both conv pairs and add the residual connection.

        Args:
            input: 4-D tensor with ``chann`` channels.

        Returns:
            ``relu(branch(input) + input)``, same shape as ``input``.
        """
        # Pair 1: 3x1 -> ReLU -> 1x3 -> BN -> ReLU
        branch = F.relu(self.conv3x1_1(input))
        branch = F.relu(self.bn1(self.conv1x3_1(branch)))

        # Pair 2: dilated 3x1 -> ReLU -> dilated 1x3 -> BN
        branch = F.relu(self.conv3x1_2(branch))
        branch = self.bn2(self.conv1x3_2(branch))

        # Only route through the dropout module when it can have an effect.
        if self.dropout.p != 0:
            branch = self.dropout(branch)

        # Identity shortcut.
        return F.relu(branch + input)
|
|
|
|
|
|
class UpsamplerBlock(nn.Module):
    """Two-branch upsampler whose branch outputs are summed.

    Branch A: stride-2 transposed convolution -> batch norm -> ReLU -> two
    ``non_bottleneck_1d`` blocks (no dropout, dilation 1).
    Branch B: bias-free 1x1 convolution -> batch norm -> ReLU -> bilinear
    resize to ``(up_height, up_width)``.

    Args:
        ninput: Number of input channels.
        noutput: Number of output channels of both branches.
        up_width: Target width of the bilinear branch.
        up_height: Target height of the bilinear branch.

    Note:
        The two branch outputs are added element-wise, so
        ``(up_height, up_width)`` has to be compatible with the spatial size
        produced by the transposed convolution.
    """

    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        # Branch A: transposed convolution followed by refinement blocks.
        self.conv = nn.ConvTranspose2d(
            ninput, noutput, kernel_size=3, stride=2, padding=1,
            output_padding=1, bias=True,
        )
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)
        self.follows = nn.ModuleList(
            [non_bottleneck_1d(noutput, 0, 1) for _ in range(2)]
        )

        # Branch B: channel projection followed by bilinear interpolation.
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(
            noutput, eps=1e-3, track_running_stats=True,
        )

    def forward(self, input):
        """Upsample ``input`` through both branches and return their sum.

        Args:
            input: 4-D tensor with ``ninput`` channels.

        Returns:
            Element-wise sum of the transposed-conv branch and the
            bilinear-interpolation branch.
        """
        # Branch A
        learned = F.relu(self.bn(self.conv(input)))
        for refine in self.follows:
            learned = refine(learned)

        # Branch B
        projected = F.relu(self.interpolate_bn(self.interpolate_conv(input)))
        resized = F.interpolate(
            projected,
            size=[self.up_height, self.up_width],
            mode='bilinear',
            align_corners=False,
        )

        return learned + resized
|
|
|
|
class BUSD(nn.Module):
    """Decoder that chains three ``UpsamplerBlock`` stages with skip fusion.

    The stages map 128 -> 64 -> 32 -> 16 channels, with bilinear-branch
    target sizes of 1/4, 1/2 and 1/1 of the configured image size. After each
    of the first two stages a skip feature is concatenated on the channel
    dimension and fused back down with a bias-free 1x1 convolution. A final
    bias-free 1x1 convolution maps 16 channels to ``cfg.num_classes``.

    Args:
        cfg: Configuration object. This class reads ``img_height``,
            ``img_width`` and ``num_classes`` from it.
    """

    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height
        img_width = cfg.img_width
        num_classes = cfg.num_classes

        # (input channels, output channels, divisor applied to the image size)
        stage_specs = ((128, 64, 4), (64, 32, 2), (32, 16, 1))

        self.layers = nn.ModuleList()
        for n_in, n_out, divisor in stage_specs:
            self.layers.append(
                UpsamplerBlock(
                    ninput=n_in,
                    noutput=n_out,
                    up_height=int(img_height) // divisor,
                    up_width=int(img_width) // divisor,
                )
            )

        # 1x1 fusions applied after concatenating the skip features.
        self.out1 = conv1x1(128, 64)
        self.out2 = conv1x1(64, 32)
        self.output_conv = conv1x1(16, num_classes)

    def forward(self, input):
        """Decode three feature maps into class maps.

        Args:
            input: Indexable holding three tensors. ``input[2]`` feeds the
                first stage; ``input[1]`` is concatenated after stage 0 and
                ``input[0]`` after stage 1. The fusion convolutions expect
                128 and 64 channels after concatenation, respectively.

        Returns:
            Output of the final 1x1 convolution, with ``num_classes``
            channels.
        """
        skip_late = input[0]
        skip_early = input[1]
        features = input[2]

        # Pair every stage with the skip it consumes and the conv that fuses
        # it; the last stage has neither.
        skips = (skip_early, skip_late, None)
        fusions = (self.out1, self.out2, None)

        for stage, skip, fuse in zip(self.layers, skips, fusions):
            features = stage(features)
            if fuse is not None:
                features = fuse(torch.cat((skip, features), dim=1))

        return self.output_conv(features)
|