from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
    """
    This layer applies multi-head self-attention as described in the
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper.

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (float): Attention dropout. Default: 0.0
        bias (bool): Use bias or not. Default: ``True``

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: float = 0.0,
        bias: bool = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)

        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x_q: Tensor) -> Tensor:
        # [N, P, C]
        b_sz, n_patches, in_channels = x_q.shape

        # self-attention
        # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
        qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)

        # [N, P, 3, h, c] -> [N, h, 3, P, c]
        qkv = qkv.transpose(1, 3).contiguous()

        # [N, h, 3, P, c] -> [N, h, P, c] x 3
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # scaled dot-product attention: softmax(Q K^T / sqrt(c)) V
        query = query * self.scaling

        # [N, h, P, c] -> [N, h, c, P]
        key = key.transpose(-1, -2)

        # QK^T
        # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
        attn = torch.matmul(query, key)
        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
        out = torch.matmul(attn, value)

        # [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
        out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
        out = self.out_proj(out)

        return out
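

# Illustrative usage sketch for MultiHeadAttention (the helper name and the
# embed_dim / num_heads / tensor sizes below are arbitrary examples, not part
# of the original API): a quick check that the layer preserves [N, P, C].
def _mha_shape_check() -> None:
    mha = MultiHeadAttention(embed_dim=64, num_heads=4, attn_dropout=0.1)
    x = torch.randn(2, 10, 64)  # [N=2, P=10 patches, C=64 channels]
    y = mha(x)
    assert y.shape == x.shape  # output shape matches the input: [2, 10, 64]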


class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_.

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        ffn_latent_dim (int): Inner dimension of the FFN
        num_heads (int): Number of heads in multi-head attention. Default: 8
        attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
        dropout (float): Dropout rate. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers. Default: 0.0

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_latent_dim: int,
        num_heads: Optional[int] = 8,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True
        )

        self.pre_norm_mha = nn.Sequential(
            nn.LayerNorm(embed_dim),
            attn_unit,
            nn.Dropout(p=dropout)
        )

        self.pre_norm_ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            nn.SiLU(),
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout)
        )

        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.std_dropout = dropout

    def forward(self, x: Tensor) -> Tensor:
        # multi-head attention
        res = x
        x = self.pre_norm_mha(x)
        x = x + res

        # feed forward network
        x = x + self.pre_norm_ffn(x)
        return x
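

# Illustrative usage sketch for TransformerEncoder (the dimensions below are
# arbitrary examples, not values taken from any particular configuration):
# one pre-norm encoder block applied to a batch of patch embeddings.
if __name__ == "__main__":
    block = TransformerEncoder(
        embed_dim=64,
        ffn_latent_dim=128,
        num_heads=4,
        attn_dropout=0.1,
        dropout=0.1,
        ffn_dropout=0.1,
    )
    x = torch.randn(2, 10, 64)  # [N=2, P=10 patches, C=64 channels]
    y = block(x)
    print(y.shape)  # torch.Size([2, 10, 64]) -- same shape as the input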