from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
    """
    This layer applies multi-head self- or cross-attention as described in the
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper.

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (float): Attention dropout. Default: 0.0
        bias (bool): Use bias or not. Default: ``True``

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: float = 0.0,
        bias: bool = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)
        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)

        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x_q: Tensor) -> Tensor:
        # [N, P, C]
        b_sz, n_patches, in_channels = x_q.shape

        # self-attention
        # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = h * c
        qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)

        # [N, P, 3, h, c] -> [N, h, 3, P, c]
        qkv = qkv.transpose(1, 3).contiguous()

        # [N, h, 3, P, c] -> [N, h, P, c] x 3
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        query = query * self.scaling

        # [N, h, P, c] -> [N, h, c, P]
        key = key.transpose(-1, -2)

        # QK^T
        # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
        attn = torch.matmul(query, key)
        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
        out = torch.matmul(attn, value)

        # [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
        out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
        out = self.out_proj(out)

        return out
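
# Usage sketch for MultiHeadAttention (illustrative sizes, not from the original
# source): a batch of N=2 sequences with P=16 patches and C=64 embedding channels
# is mapped to an output of the same shape.
#
#   >>> mha = MultiHeadAttention(embed_dim=64, num_heads=4)
#   >>> x = torch.randn(2, 16, 64)   # [N, P, C]
#   >>> tuple(mha(x).shape)
#   (2, 16, 64)
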
class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_.

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        ffn_latent_dim (int): Inner dimension of the FFN
        num_heads (int): Number of heads in multi-head attention. Default: 8
        attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
        dropout (float): Dropout rate. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers. Default: 0.0

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_latent_dim: int,
        num_heads: Optional[int] = 8,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True,
        )

        # pre-norm attention block: LayerNorm -> multi-head attention -> dropout
        self.pre_norm_mha = nn.Sequential(
            nn.LayerNorm(embed_dim),
            attn_unit,
            nn.Dropout(p=dropout),
        )

        # pre-norm feed-forward block: LayerNorm -> expand -> SiLU -> project back
        self.pre_norm_ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            nn.SiLU(),
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout),
        )

        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.std_dropout = dropout

    def forward(self, x: Tensor) -> Tensor:
        # multi-head attention with a residual connection
        res = x
        x = self.pre_norm_mha(x)
        x = x + res

        # feed-forward network with a residual connection
        x = x + self.pre_norm_ffn(x)
        return x
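

if __name__ == "__main__":
    # Minimal sanity check (illustrative hyperparameters, not from the original
    # configuration): run a random [N, P, C] tensor through a single pre-norm
    # encoder block and confirm that the output shape matches the input shape.
    encoder = TransformerEncoder(embed_dim=64, ffn_latent_dim=128, num_heads=4)
    x = torch.randn(2, 16, 64)  # [N, P, C] = [batch, patches, embedding dim]
    y = encoder(x)
    assert y.shape == x.shape, f"Expected {tuple(x.shape)}, got {tuple(y.shape)}"
    print("TransformerEncoder output shape:", tuple(y.shape))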