198 lines
8.1 KiB
Python
198 lines
8.1 KiB
Python
# 2022.06.17-Changed for building ViG model
|
|
# Huawei Technologies Co., Ltd. <foss@huawei.com>
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
from .torch_nn import BasicConv, batched_index_select, act_layer
|
|
from .torch_edge import DenseDilatedKnnGraph
|
|
from .pos_embed import get_2d_relative_pos_embed
|
|
import torch.nn.functional as F
|
|
from .DropPath import DropPath
|
|
|
|
|
|
class MRConv2d(nn.Module):
    """
    Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type.

    Every centre node's feature is paired with the per-channel maximum over its
    neighbours of ``x_j - x_i``. The pair is interleaved along the channel axis
    and fed to ``self.nn``.

    Args:
        in_channels: channel count C of the node features.
        out_channels: channel count produced by ``self.nn``.
        act, norm, bias: forwarded to ``BasicConv``.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(MRConv2d, self).__init__()
        # self.nn sees the interleaved [x, max-relative] stack, hence 2 * C inputs.
        self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias)

    def forward(self, x, edge_index, y=None):
        """Aggregate max-relative neighbour features.

        Args:
            x: node features. Assumed to be (B, C, N, 1), the layout
                DyGraphConv2d produces; TODO confirm for other callers.
            edge_index: ``edge_index[1]`` selects centre nodes from ``x``;
                ``edge_index[0]`` selects neighbour nodes.
            y: optional tensor that neighbours are gathered from; defaults to ``x``.
        """
        neighbour_src = x if y is None else y
        centre = batched_index_select(x, edge_index[1])
        neighbour = batched_index_select(neighbour_src, edge_index[0])
        # Reduce over the last (neighbour) axis, keeping it as a size-1 dim.
        rel_max = torch.max(neighbour - centre, -1, keepdim=True)[0]
        batch, channels, num_nodes, width = x.shape
        # Stacking on dim 2 then reshaping interleaves channels: [x_0, r_0, x_1, r_1, ...].
        paired = torch.stack([x, rel_max], dim=2).reshape(batch, 2 * channels, num_nodes, width)
        return self.nn(paired)
|
|
|
|
|
|
class EdgeConv2d(nn.Module):
    """
    Edge convolution (EdgeConv) for dense node tensors.

    For each centre node i and neighbour j, ``self.nn`` (a BasicConv whose
    activation/normalisation come from ``act``/``norm``) is applied to the edge
    feature ``[x_i, x_j - x_i]``. The result is max-reduced over the neighbour
    (last) axis.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(EdgeConv2d, self).__init__()
        # Edge features concatenate x_i and x_j - x_i along channels -> 2 * C inputs.
        self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)

    def forward(self, x, edge_index, y=None):
        """Return max-aggregated edge features (last axis kept with size 1).

        ``edge_index[1]`` picks centre nodes from ``x``; ``edge_index[0]`` picks
        neighbours from ``y`` when given, otherwise from ``x``.
        """
        centre = batched_index_select(x, edge_index[1])
        neighbour = batched_index_select(x if y is None else y, edge_index[0])
        edge_features = torch.cat([centre, neighbour - centre], dim=1)
        return self.nn(edge_features).max(dim=-1, keepdim=True)[0]
|
|
|
|
|
|
class GraphSAGE(nn.Module):
    """
    GraphSAGE Graph Convolution (Paper: https://arxiv.org/abs/1706.02216) for dense data type.

    Neighbour features are transformed by ``nn1`` and max-pooled over the
    neighbour (last) axis. The result is concatenated with ``x`` on the channel
    axis and mapped by ``nn2``.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(GraphSAGE, self).__init__()
        # nn1 keeps the channel count; nn2 consumes [x, pooled] -> 2 * C inputs.
        self.nn1 = BasicConv([in_channels, in_channels], act, norm, bias)
        self.nn2 = BasicConv([in_channels*2, out_channels], act, norm, bias)

    def forward(self, x, edge_index, y=None):
        """Aggregate neighbours picked by ``edge_index[0]`` from ``y`` (or ``x`` if None)."""
        neighbours = batched_index_select(x if y is None else y, edge_index[0])
        pooled = self.nn1(neighbours).max(dim=-1, keepdim=True).values
        return self.nn2(torch.cat((x, pooled), dim=1))
|
|
|
|
|
|
class GINConv2d(nn.Module):
    """
    GIN Graph Convolution (Paper: https://arxiv.org/abs/1810.00826) for dense data type.

    Computes ``nn((1 + eps) * x + sum_j x_j)``. The sum runs over the neighbour
    (last) axis, and ``eps`` is a learnable scalar initialised to 0.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(GINConv2d, self).__init__()
        self.nn = BasicConv([in_channels, out_channels], act, norm, bias)
        initial_eps = 0.0
        # One-element learnable epsilon of the GIN update.
        self.eps = nn.Parameter(torch.Tensor([initial_eps]))

    def forward(self, x, edge_index, y=None):
        """Sum neighbours picked by ``edge_index[0]`` from ``y`` (or ``x`` if None)."""
        source = y if y is not None else x
        neighbour_sum = batched_index_select(source, edge_index[0]).sum(dim=-1, keepdim=True)
        return self.nn((1 + self.eps) * x + neighbour_sum)
|
|
|
|
|
|
class GraphConv2d(nn.Module):
    """
    Static graph convolution layer.

    Wraps one aggregation operator chosen by ``conv``:
    ``'edge'`` -> EdgeConv2d, ``'mr'`` -> MRConv2d, ``'sage'`` -> GraphSAGE,
    ``'gin'`` -> GINConv2d. Any other value raises NotImplementedError.
    """

    def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True):
        super(GraphConv2d, self).__init__()
        # Pick the operator class first, then build it with a single call.
        if conv == 'edge':
            conv_cls = EdgeConv2d
        elif conv == 'mr':
            conv_cls = MRConv2d
        elif conv == 'sage':
            conv_cls = GraphSAGE
        elif conv == 'gin':
            conv_cls = GINConv2d
        else:
            raise NotImplementedError('conv:{} is not supported'.format(conv))
        self.gconv = conv_cls(in_channels, out_channels, act, norm, bias)

    def forward(self, x, edge_index, y=None):
        """Delegate to the wrapped operator with the given edges and optional neighbour source."""
        return self.gconv(x, edge_index, y)
|
|
|
|
|
|
class DyGraphConv2d(GraphConv2d):
    """
    Dynamic graph convolution layer.

    On every forward pass the neighbour graph is recomputed from the current
    features by ``DenseDilatedKnnGraph`` (presumably a dilated k-nearest-neighbour
    search with k = ``kernel_size``). The graph is then fed to the static
    GraphConv2d aggregation. With ``r > 1`` the neighbour candidates come from an
    ``r x r`` average-pooled copy of the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
                 norm=None, bias=True, stochastic=False, epsilon=0.0, r=1):
        super(DyGraphConv2d, self).__init__(in_channels, out_channels, conv, act, norm, bias)
        self.k, self.d, self.r = kernel_size, dilation, r
        self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

    def forward(self, x, relative_pos=None):
        """Build the graph for the 4-D feature map ``x`` and aggregate over it.

        Args:
            x: (B, C, H, W) feature map.
            relative_pos: optional bias forwarded unchanged to the graph builder.

        Returns:
            Tensor reshaped back to (B, C_out, H, W).
        """
        batch, channels, height, width = x.shape
        # Flatten the spatial grid into a column of nodes: (B, C, H*W, 1).
        nodes = x.reshape(batch, channels, -1, 1).contiguous()
        if self.r > 1:
            pooled = F.avg_pool2d(x, self.r, self.r)
            neighbours = pooled.reshape(batch, channels, -1, 1).contiguous()
        else:
            neighbours = None
        # Select the nearest neighbour nodes used by the subsequent aggregation.
        edge_index = self.dilated_knn_graph(nodes, neighbours, relative_pos)
        out = super(DyGraphConv2d, self).forward(nodes, edge_index, neighbours)
        return out.reshape(batch, -1, height, width).contiguous()
|
|
|
|
|
|
class Grapher(nn.Module):
    """
    Grapher module with graph convolution and fc layers.

    ``forward`` computes ``drop_path(fc2(graph_conv(fc1(x)))) + x``. Here fc1/fc2
    are 1x1 Conv2d + BatchNorm2d, and ``graph_conv`` is a DyGraphConv2d that maps
    ``in_channels`` to ``2 * in_channels``; fc2 maps it back.

    Args:
        in_channels: channel count of both the input and the output.
        kernel_size, dilation, conv, act, norm, bias, stochastic, epsilon:
            forwarded to DyGraphConv2d.
        r: reduction factor forwarded to DyGraphConv2d, which average-pools the
            neighbour source by ``r``. Also used to size the relative-position bias.
        n: node count H * W for which the relative-position bias is precomputed.
        drop_path: drop-path rate; 0 uses ``nn.Identity``.
        relative_pos: if True, precompute a frozen relative-position bias of shape
            ``(1, n, n // r**2)``; otherwise ``self.relative_pos`` is None.
    """

    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
        super(Grapher, self).__init__()
        self.channels = in_channels
        self.n = n
        self.r = r
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
                                        act, norm, bias, stochastic, epsilon, r)
        self.fc2 = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.relative_pos = None
        if relative_pos:
            # The embedding is requested for grid side int(sqrt(n)). This assumes
            # the n nodes form a square grid (TODO confirm). The embedding is then
            # bicubically resampled to (n, n // r**2).
            relative_pos_tensor = torch.from_numpy(np.float32(get_2d_relative_pos_embed(in_channels,
                int(n**0.5)))).unsqueeze(0).unsqueeze(1)
            relative_pos_tensor = F.interpolate(
                relative_pos_tensor, size=(n, n//(r*r)), mode='bicubic', align_corners=False)
            # Stored negated and frozen. It is an nn.Parameter (not a buffer), so it
            # shows up in parameters() and state_dict().
            self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1), requires_grad=False)

    def _get_relative_pos(self, relative_pos, H, W):
        """Return a relative-position bias sized for an H x W input.

        ``relative_pos`` is returned unchanged when it is None or when H * W equals
        the node count ``self.n`` it was built for. Otherwise the ``(1, n, n')`` bias
        from ``__init__`` is bicubically resampled to ``(1, H*W, N_y)``. Here
        ``N_y = (H // r) * (W // r)`` is the number of nodes that
        ``F.avg_pool2d(x, r, r)`` yields in DyGraphConv2d.
        """
        if relative_pos is None or H * W == self.n:
            return relative_pos
        N = H * W
        # avg_pool2d floors each side separately, so use the pooled grid size rather
        # than N // r**2. The two agree when H and W are divisible by r, but e.g.
        # H = W = 30, r = 4 gives 7 * 7 = 49 pooled nodes, not 900 // 16 = 56.
        N_reduced = (H // self.r) * (W // self.r)
        return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced),
                             mode="bicubic", align_corners=False).squeeze(0)

    def forward(self, x):
        """Apply fc1 -> dynamic graph conv -> fc2 -> drop-path and add the input back."""
        shortcut = x
        x = self.fc1(x)
        _, _, H, W = x.shape
        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
        x = self.graph_conv(x, relative_pos)
        x = self.fc2(x)
        x = self.drop_path(x) + shortcut
        return x
|