# Copyright (c) 2024, Tri Dao, Albert Gu.

import torch

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.jit
def _swiglu_fwd_kernel(
    X,
    Y,
    OUT,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_out_row,
    ncols,
    BLOCK_N: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    OUT += row * stride_out_row
    cols = start_col + tl.arange(0, BLOCK_N)
    # Load one BLOCK_N-wide tile of x and y, upcast to fp32 for the math.
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    # SwiGLU: out = silu(x) * y, where silu(x) = x * sigmoid(x).
    out = x * tl.sigmoid(x) * y
    tl.store(OUT + cols, out, mask=cols < ncols)
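

# For illustration only: a pure-PyTorch reference of the forward computed by
# the kernel above, useful for testing. The helper name `_swiglu_fwd_ref` is
# ours, not part of the original module; `xy` is assumed to pack the gate half
# x and the value half y along the last dimension, as `_swiglu_fwd` does below.
def _swiglu_fwd_ref(xy):
    x, y = xy.chunk(2, dim=-1)
    # silu(x) * y == x * sigmoid(x) * y, matching the kernel's fp32 math.
    return torch.nn.functional.silu(x) * y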


def _swiglu_fwd(xy, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    # xy packs the gate half x and the value half y along the last dimension.
    x, y = xy.chunk(2, dim=-1)
    if out is None:
        out = torch.empty_like(x)
    else:
        out = out.reshape(-1, out.shape[-1])
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    M, N = x.shape
    # One Triton program per (row, BLOCK_N-sized column block).
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
    return out.reshape(*batch_shape, out.shape[-1])
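

# A minimal sanity check of the wrapper against the reference; this usage
# sketch is ours, not part of the original file, and assumes a CUDA device
# (the kernel launch requires one).
def _test_swiglu_fwd(M=4, N=256, dtype=torch.float16):
    xy = torch.randn(M, 2 * N, device='cuda', dtype=dtype)
    out = _swiglu_fwd(xy)
    ref = _swiglu_fwd_ref(xy.float()).to(dtype)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)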


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
@triton.jit
def _swiglu_bwd_kernel(
    X,
    Y,
    DOUT,
    OUT,
    DX,
    DY,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dout_row,
    stride_out_row,
    stride_dx_row,
    stride_dy_row,
    ncols,
    BLOCK_N: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    DOUT += row * stride_dout_row
    if RECOMPUTE_OUTPUT:
        OUT += row * stride_out_row
    DX += row * stride_dx_row
    DY += row * stride_dy_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
    x_sigmoid = tl.sigmoid(x)
    # With out = silu(x) * y and d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))):
    # dx = silu'(x) * y * dout, dy = silu(x) * dout.
    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
    dy = x * x_sigmoid * dout
    tl.store(DX + cols, dx, mask=cols < ncols)
    tl.store(DY + cols, dy, mask=cols < ncols)
    if RECOMPUTE_OUTPUT:
        # Recompute the forward output so it needn't be saved in the fwd pass.
        out = x * x_sigmoid * y
        tl.store(OUT + cols, out, mask=cols < ncols)
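

# For illustration only: the backward pass in pure PyTorch, handy for testing
# the kernel above. Derived from out = silu(x) * y using the same identity
# d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))). The helper name
# `_swiglu_bwd_ref` is ours, not part of the original module.
def _swiglu_bwd_ref(xy, dout):
    x, y = xy.chunk(2, dim=-1)
    x_sigmoid = torch.sigmoid(x)
    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
    dy = x * x_sigmoid * dout
    # Pack dx and dy along the last dimension, mirroring the layout of xy.
    return torch.cat([dx, dy], dim=-1)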


def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    x, y = xy.chunk(2, dim=-1)
    dout = dout.reshape(-1, dout.shape[-1])
    assert dout.shape == x.shape
    if dxy is None:
        dxy = torch.empty_like(xy)
    else:
        dxy = dxy.reshape(-1, dxy.shape[-1])
        assert dxy.shape == xy.shape
    dx, dy = dxy.chunk(2, dim=-1)
    assert dx.stride(-1) == 1
    assert dy.stride(-1) == 1
    # Optionally have the kernel recompute the forward output, so the forward
    # pass doesn't need to store it.
    if recompute_output:
        if out is None:
            out = torch.empty_like(x)
        else:
            out = out.reshape(-1, out.shape[-1])
            assert out.shape == x.shape
        assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
                                 x.stride(0), y.stride(0), dout.stride(0),
                                 out.stride(0) if recompute_output else 0,
                                 dx.stride(0), dy.stride(0),
                                 N)
    if not recompute_output:
        return dxy.reshape(*batch_shape, dxy.shape[-1])
    else:
        return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
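

# A minimal autograd wrapper tying the two launchers together. This sketch and
# the class name `_SwiGLUSketch` are ours, for illustration only; the repo
# wires these launchers into its own autograd functions elsewhere.
class _SwiGLUSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xy):
        ctx.save_for_backward(xy)
        return _swiglu_fwd(xy)

    @staticmethod
    def backward(ctx, dout):
        xy, = ctx.saved_tensors
        # _swiglu_bwd returns dxy with the same packed layout as xy.
        return _swiglu_bwd(xy, dout)


# Usage sketch: out = _SwiGLUSketch.apply(xy), with xy packing x and y along
# the last dimension on a CUDA device.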