200 lines
7.4 KiB
Python
200 lines
7.4 KiB
Python
|
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||
|
|
||
|
import math
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
from einops import rearrange, repeat
|
||
|
|
||
|
try:
|
||
|
from causal_conv1d import causal_conv1d_fn
|
||
|
except ImportError:
|
||
|
causal_conv1d_fn = None
|
||
|
|
||
|
try:
|
||
|
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
|
||
|
except ImportError:
|
||
|
RMSNormGated, LayerNorm = None, None
|
||
|
|
||
|
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
||
|
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
||
|
|
||
|
|
||
|
class Mamba2Simple(nn.Module):
    """Simplified Mamba-2 mixer block.

    Pipeline (see ``forward``): input projection -> depthwise 1D convolution
    with SiLU/Swish activation over the ``[x, B, C]`` channels -> chunked scan
    (``mamba_chunk_scan_combined``) -> gated RMSNorm -> output projection.
    When ``use_mem_eff_path`` is true, everything after the input projection
    is delegated to the single fused call
    ``mamba_split_conv1d_scan_combined``.
    """

    def __init__(
        self,
        d_model,
        d_state=64,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=128,
        ngroups=1,
        A_init_range=(1, 16),
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        learnable_init_states=False,
        activation="swish",
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        device=None,
        dtype=None,
    ):
        """Build the projections, convolution and SSM parameters.

        Args:
            d_model: Feature size of the block's input and output.
            d_state: State size per group; B and C each occupy
                ``ngroups * d_state`` channels of the input projection.
            d_conv: Kernel size of the depthwise convolution.
            conv_init: If not None, conv weights are re-initialized uniformly
                in ``[-conv_init, conv_init]``.
            expand: Expansion factor; ``d_inner = expand * d_model``.
            headdim: Channels per head; ``d_inner`` must be divisible by it.
            ngroups: Number of groups used for the B and C projections.
            A_init_range: ``(low, high)`` range A is sampled from uniformly
                (stored as ``A_log``); requires ``0 < low <= high``.
            dt_min: Lower end of the log-uniform range used to initialize dt.
            dt_max: Upper end of the log-uniform range used to initialize dt.
            dt_init_floor: Minimum value the initial dt is clamped to.
            dt_limit: Forwarded to the scan functions as ``dt_limit`` unless it
                equals the default ``(0.0, inf)``.
            learnable_init_states: If true, create a zero-initialized
                ``init_states`` parameter of shape
                ``(nheads, headdim, d_state)`` that is passed to the scan as
                ``initial_states``.
            activation: Activation name forwarded to the conv/scan functions;
                the non-fused path only accepts ``"silu"`` or ``"swish"``.
            bias: Whether ``in_proj`` and ``out_proj`` have a bias term.
            conv_bias: Whether the depthwise convolution has a bias term.
            chunk_size: Chunk size forwarded to the scan functions.
            use_mem_eff_path: If true, ``forward`` uses the fully fused call.
            layer_idx: Stored on the module; not otherwise used in this class.
            device: Device for the created parameters.
            dtype: Dtype for the created parameters.

        Raises:
            AssertionError: If ``d_inner`` is not divisible by ``headdim``,
                if ``A_init_range`` is invalid, or if the gated RMSNorm
                implementation could not be imported.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0, "d_inner must be divisible by headdim"
        self.nheads = self.d_inner // self.headdim
        self.dt_limit = dt_limit
        self.learnable_init_states = learnable_init_states
        self.activation = activation
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)

        # Depthwise (groups == channels) convolution over the [x, B, C] channels.
        # padding=d_conv - 1 pads BOTH ends, so the raw output is longer than
        # the input; forward() trims it back when this module is called directly.
        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

        if self.learnable_init_states:
            self.init_states = nn.Parameter(
                torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
            )
            self.init_states._no_weight_decay = True

        self.act = nn.SiLU()

        # Initialize log dt bias: sample dt log-uniformly in [dt_min, dt_max],
        # clamp from below, then store softplus^{-1}(dt) so that
        # softplus(dt_bias) reproduces the sampled dt.
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        # A parameter: sampled in float32, stored as log(A); forward() uses -exp(A_log).
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0], (
            "A_init_range must satisfy 0 < low <= high"
        )
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.nheads, device=device))
        self.D._no_weight_decay = True

        # Extra normalization layer right before output projection
        assert RMSNormGated is not None, (
            "RMSNormGated could not be imported from mamba_ssm.ops.triton.layernorm_gated"
        )
        self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, u, seq_idx=None):
        """Apply the block to a batch of sequences.

        Args:
            u: Input of shape (B, L, D).
            seq_idx: Optional argument forwarded unchanged to the scan
                functions as ``seq_idx``.

        Returns:
            Tensor with the same shape as ``u``.
        """
        batch, seqlen, dim = u.shape

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (nheads,)
        initial_states = (
            repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
        )
        # Only pass dt_limit when it differs from the default (0, inf).
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)

        if self.use_mem_eff_path:
            # Fully fused path
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=self.D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.eps,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=False,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
        else:
            z, xBC, dt = torch.split(
                zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
            )
            dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)
            assert self.activation in ["silu", "swish"], "activation must be 'silu' or 'swish'"

            # 1D Convolution
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
                )  # (B, L + d_conv - 1, self.d_inner + 2 * ngroups * d_state)
                # nn.Conv1d pads both ends by d_conv - 1, so the output has
                # d_conv - 1 extra trailing steps. Keep only the first `seqlen`
                # steps so the time axis lines up with dt and z below.
                xBC = xBC[:, :seqlen, :]
            else:
                xBC = causal_conv1d_fn(
                    x=xBC.transpose(1, 2),
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                ).transpose(1, 2)

            # Split into 3 main branches: X, B, C
            # These correspond to V, K, Q respectively in the SSM/attention duality
            x, B, C = torch.split(
                xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1
            )
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,  # gating is applied by self.norm below, not inside the scan
                seq_idx=seq_idx,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
            y = rearrange(y, "b l h p -> b l (h p)")

            # Multiply "gate" branch and apply extra normalization layer
            y = self.norm(y, z)
            out = self.out_proj(y)
        return out
|