new code
This commit is contained in:
commit
9eea6c07af
8
Mamba/.idea/.gitignore
generated
vendored
Normal file
8
Mamba/.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
15
Mamba/.idea/Manba.iml
generated
Normal file
15
Mamba/.idea/Manba.iml
generated
Normal file
@ -0,0 +1,15 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
<option name="PROJECT_TEST_RUNNER" value="py.test" />
|
||||
</component>
|
||||
</module>
|
||||
15
Mamba/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
15
Mamba/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@ -0,0 +1,15 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredErrors">
|
||||
<list>
|
||||
<option value="N803" />
|
||||
<option value="N802" />
|
||||
<option value="N806" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
6
Mamba/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
Mamba/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
Mamba/.idea/misc.xml
generated
Normal file
7
Mamba/.idea/misc.xml
generated
Normal file
@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.6 (lxmert-master)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (lxmert-master)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
Mamba/.idea/modules.xml
generated
Normal file
8
Mamba/.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/Manba.iml" filepath="$PROJECT_DIR$/.idea/Manba.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
39
Mamba/mamba-130f-hf/config.json
Normal file
39
Mamba/mamba-130f-hf/config.json
Normal file
@ -0,0 +1,39 @@
|
||||
{
|
||||
"architectures": [
|
||||
"MambaForCausalLM"
|
||||
],
|
||||
"bos_token_id": 0,
|
||||
"conv_kernel": 4,
|
||||
"d_inner": 1536,
|
||||
"d_model": 768,
|
||||
"eos_token_id": 0,
|
||||
"expand": 2,
|
||||
"fused_add_norm": true,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 768,
|
||||
"initializer_range": 0.1,
|
||||
"intermediate_size": 1536,
|
||||
"layer_norm_epsilon": 1e-05,
|
||||
"model_type": "mamba",
|
||||
"n_layer": 24,
|
||||
"num_hidden_layers": 24,
|
||||
"pad_token_id": 0,
|
||||
"pad_vocab_size_multiple": 8,
|
||||
"rescale_prenorm_residual": false,
|
||||
"residual_in_fp32": true,
|
||||
"rms_norm": true,
|
||||
"ssm_cfg": {},
|
||||
"state_size": 16,
|
||||
"time_step_floor": 0.0001,
|
||||
"time_step_init_scheme": "random",
|
||||
"time_step_max": 0.1,
|
||||
"time_step_min": 0.001,
|
||||
"time_step_rank": 48,
|
||||
"time_step_scale": 1.0,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.39.0.dev0",
|
||||
"use_bias": false,
|
||||
"use_cache": true,
|
||||
"use_conv_bias": true,
|
||||
"vocab_size": 50280
|
||||
}
|
||||
7
Mamba/mamba-130f-hf/generation_config.json
Normal file
7
Mamba/mamba-130f-hf/generation_config.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"_from_model_config": true,
|
||||
"bos_token_id": 0,
|
||||
"pad_token_id": 0,
|
||||
"eos_token_id": 0,
|
||||
"transformers_version": "4.39.0.dev0"
|
||||
}
|
||||
BIN
Mamba/mamba-130f-hf/model.safetensors
Normal file
BIN
Mamba/mamba-130f-hf/model.safetensors
Normal file
Binary file not shown.
100534
Mamba/mamba-130f-hf/tokenizer.json
Normal file
100534
Mamba/mamba-130f-hf/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
212
Mamba/mamba-130f-hf/tokenizer_config.json
Normal file
212
Mamba/mamba-130f-hf/tokenizer_config.json
Normal file
@ -0,0 +1,212 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "<|padding|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"50254": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50255": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50256": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50257": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50258": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50259": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50260": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50261": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50262": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50263": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50264": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50265": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50266": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50267": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50268": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50269": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50270": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50271": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50272": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50273": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50274": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50275": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"50276": {
|
||||
"content": " ",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
}
|
||||
},
|
||||
"bos_token": "<|endoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"model_max_length": 1000000000000000019884624838656,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"tokenizer_class": "GPTNeoXTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
231
Mamba/mamba-main/.github/workflows/publish.yaml
vendored
Normal file
231
Mamba/mamba-main/.github/workflows/publish.yaml
vendored
Normal file
@ -0,0 +1,231 @@
|
||||
# This workflow will:
|
||||
# - Create a new Github release
|
||||
# - Build wheels for supported architectures
|
||||
# - Deploy the wheels to the Github release
|
||||
# - Release the static code to PyPi
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
||||
|
||||
name: Build wheels and deploy
|
||||
|
||||
on:
|
||||
create:
|
||||
tags:
|
||||
- v*
|
||||
|
||||
jobs:
|
||||
|
||||
setup_release:
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Get the tag version
|
||||
id: extract_branch
|
||||
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
||||
shell: bash
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
||||
release_name: ${{ steps.extract_branch.outputs.branch }}
|
||||
|
||||
build_wheels:
|
||||
name: Build Wheel
|
||||
needs: setup_release
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
|
||||
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
|
||||
os: [ubuntu-20.04]
|
||||
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
|
||||
torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.2', '2.3.0', '2.4.0.dev20240420']
|
||||
cuda-version: ['11.8.0', '12.2.2']
|
||||
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
|
||||
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
|
||||
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
|
||||
# when building without C++11 ABI and using it on nvcr images.
|
||||
cxx11_abi: ['FALSE', 'TRUE']
|
||||
exclude:
|
||||
# Pytorch < 2.2 does not support Python 3.12
|
||||
- torch-version: '1.12.1'
|
||||
python-version: '3.12'
|
||||
- torch-version: '1.13.1'
|
||||
python-version: '3.12'
|
||||
- torch-version: '2.0.1'
|
||||
python-version: '3.12'
|
||||
- torch-version: '2.1.2'
|
||||
python-version: '3.12'
|
||||
# Pytorch <= 1.12 does not support Python 3.11
|
||||
- torch-version: '1.12.1'
|
||||
python-version: '3.11'
|
||||
# Pytorch >= 2.0 only supports Python >= 3.8
|
||||
- torch-version: '2.0.1'
|
||||
python-version: '3.7'
|
||||
- torch-version: '2.1.2'
|
||||
python-version: '3.7'
|
||||
- torch-version: '2.2.2'
|
||||
python-version: '3.7'
|
||||
- torch-version: '2.3.0'
|
||||
python-version: '3.7'
|
||||
- torch-version: '2.4.0.dev20240420'
|
||||
python-version: '3.7'
|
||||
# Pytorch <= 2.0 only supports CUDA <= 11.8
|
||||
- torch-version: '1.12.1'
|
||||
cuda-version: '12.2.2'
|
||||
- torch-version: '1.13.1'
|
||||
cuda-version: '12.2.2'
|
||||
- torch-version: '2.0.1'
|
||||
cuda-version: '12.2.2'
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set CUDA and PyTorch versions
|
||||
run: |
|
||||
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
|
||||
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
|
||||
|
||||
- name: Free up disk space
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
|
||||
# https://github.com/easimon/maximize-build-space/tree/test-report
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Set up swap space
|
||||
if: runner.os == 'Linux'
|
||||
uses: pierotofy/set-swap-space@v1.0
|
||||
with:
|
||||
swap-size-gb: 10
|
||||
|
||||
- name: Install CUDA ${{ matrix.cuda-version }}
|
||||
if: ${{ matrix.cuda-version != 'cpu' }}
|
||||
uses: Jimver/cuda-toolkit@v0.2.14
|
||||
id: cuda-toolkit
|
||||
with:
|
||||
cuda: ${{ matrix.cuda-version }}
|
||||
linux-local-args: '["--toolkit"]'
|
||||
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
|
||||
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
|
||||
method: 'network'
|
||||
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
|
||||
# not just nvcc
|
||||
# sub-packages: '["nvcc"]'
|
||||
|
||||
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
# If we don't install before installing Pytorch, we get error for torch 2.0.1
|
||||
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
|
||||
pip install lit
|
||||
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
|
||||
pip install setuptools
|
||||
# We want to figure out the CUDA version to download pytorch
|
||||
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
|
||||
# This code is ugly, maybe there's a better way to do this.
|
||||
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
|
||||
minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
|
||||
maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
|
||||
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
|
||||
)
|
||||
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
|
||||
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
|
||||
else
|
||||
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
|
||||
fi
|
||||
nvcc --version
|
||||
python --version
|
||||
python -c "import torch; print('PyTorch:', torch.__version__)"
|
||||
python -c "import torch; print('CUDA:', torch.version.cuda)"
|
||||
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
|
||||
shell:
|
||||
bash
|
||||
|
||||
- name: Build wheel
|
||||
run: |
|
||||
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
|
||||
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
|
||||
# However this still fails so I'm using a newer version of setuptools
|
||||
pip install setuptools==68.0.0
|
||||
pip install ninja packaging wheel
|
||||
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
# Limit MAX_JOBS otherwise the github runner goes OOM
|
||||
MAX_JOBS=2 MAMBA_FORCE_BUILD="TRUE" MAMBA_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
|
||||
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
|
||||
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
|
||||
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
|
||||
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
|
||||
|
||||
- name: Log Built Wheels
|
||||
run: |
|
||||
ls dist
|
||||
|
||||
- name: Get the tag version
|
||||
id: extract_branch
|
||||
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
||||
|
||||
- name: Get Release with tag
|
||||
id: get_current_release
|
||||
uses: joutvhu/get-release@v1
|
||||
with:
|
||||
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Upload Release Asset
|
||||
id: upload_release_asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
|
||||
asset_path: ./dist/${{env.wheel_name}}
|
||||
asset_name: ${{env.wheel_name}}
|
||||
asset_content_type: application/*
|
||||
|
||||
publish_package:
|
||||
name: Publish package
|
||||
needs: [build_wheels]
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install ninja packaging setuptools wheel twine
|
||||
# We don't want to download anything CUDA-related here
|
||||
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
- name: Build core package
|
||||
env:
|
||||
MAMBA_SKIP_CUDA_BUILD: "TRUE"
|
||||
run: |
|
||||
python setup.py sdist --dist-dir=dist
|
||||
|
||||
- name: Deploy
|
||||
env:
|
||||
TWINE_USERNAME: "__token__"
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
run: |
|
||||
python -m twine upload dist/*
|
||||
6
Mamba/mamba-main/.gitignore
vendored
Normal file
6
Mamba/mamba-main/.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
*__pycache__/
|
||||
*.egg-info/
|
||||
build/
|
||||
**.so
|
||||
*.hip
|
||||
*_hip.*
|
||||
3
Mamba/mamba-main/.gitmodules
vendored
Normal file
3
Mamba/mamba-main/.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
[submodule "3rdparty/lm-evaluation-harness"]
|
||||
path = 3rdparty/lm-evaluation-harness
|
||||
url = https://github.com/EleutherAI/lm-evaluation-harness/
|
||||
8
Mamba/mamba-main/.idea/.gitignore
generated
vendored
Normal file
8
Mamba/mamba-main/.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
15
Mamba/mamba-main/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
15
Mamba/mamba-main/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@ -0,0 +1,15 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredErrors">
|
||||
<list>
|
||||
<option value="N803" />
|
||||
<option value="N802" />
|
||||
<option value="N806" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
6
Mamba/mamba-main/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
Mamba/mamba-main/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
17
Mamba/mamba-main/.idea/mamba-main.iml
generated
Normal file
17
Mamba/mamba-main/.idea/mamba-main.iml
generated
Normal file
@ -0,0 +1,17 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
<option name="PROJECT_TEST_RUNNER" value="py.test" />
|
||||
</component>
|
||||
</module>
|
||||
4
Mamba/mamba-main/.idea/misc.xml
generated
Normal file
4
Mamba/mamba-main/.idea/misc.xml
generated
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (mamba-main)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
Mamba/mamba-main/.idea/modules.xml
generated
Normal file
8
Mamba/mamba-main/.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/mamba-main.iml" filepath="$PROJECT_DIR$/.idea/mamba-main.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
2
Mamba/mamba-main/AUTHORS
Normal file
2
Mamba/mamba-main/AUTHORS
Normal file
@ -0,0 +1,2 @@
|
||||
Tri Dao, tri@tridao.me
|
||||
Albert Gu, agu@andrew.cmu.edu
|
||||
201
Mamba/mamba-main/LICENSE
Normal file
201
Mamba/mamba-main/LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2023 Tri Dao, Albert Gu
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
241
Mamba/mamba-main/README.md
Normal file
241
Mamba/mamba-main/README.md
Normal file
@ -0,0 +1,241 @@
|
||||
# Mamba
|
||||
|
||||

|
||||
> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
|
||||
> Albert Gu*, Tri Dao*\
|
||||
> Paper: https://arxiv.org/abs/2312.00752
|
||||
|
||||

|
||||
> **Transformers are SSMs: Generalized Models and Efficient Algorithms**\
|
||||
> **Through Structured State Space Duality**\
|
||||
> Tri Dao*, Albert Gu*\
|
||||
> Paper: https://arxiv.org/abs/2405.21060
|
||||
|
||||
## About
|
||||
|
||||
Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
|
||||
It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
|
||||
with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
|
||||
|
||||
## Installation
|
||||
|
||||
- [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
|
||||
- `pip install mamba-ssm`: the core Mamba package.
|
||||
|
||||
It can also be built from source with `pip install .` from this repository.
|
||||
|
||||
If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
|
||||
|
||||
Other requirements:
|
||||
- Linux
|
||||
- NVIDIA GPU
|
||||
- PyTorch 1.12+
|
||||
- CUDA 11.6+
|
||||
|
||||
For AMD cards, see additional prerequisites below.
|
||||
|
||||
## Usage
|
||||
|
||||
We expose several levels of interface with the Mamba model.
|
||||
|
||||
### Selective SSM
|
||||
|
||||
Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
|
||||
|
||||
Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
|
||||
|
||||
### Mamba Block
|
||||
|
||||
The main module of this repository is the Mamba architecture block wrapping the selective SSM.
|
||||
|
||||
Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
|
||||
|
||||
Usage:
|
||||
``` python
|
||||
import torch
|
||||
from mamba_ssm import Mamba
|
||||
|
||||
batch, length, dim = 2, 64, 16
|
||||
x = torch.randn(batch, length, dim).to("cuda")
|
||||
model = Mamba(
|
||||
# This module uses roughly 3 * expand * d_model^2 parameters
|
||||
d_model=dim, # Model dimension d_model
|
||||
d_state=16, # SSM state expansion factor
|
||||
d_conv=4, # Local convolution width
|
||||
expand=2, # Block expansion factor
|
||||
).to("cuda")
|
||||
y = model(x)
|
||||
assert y.shape == x.shape
|
||||
```
|
||||
|
||||
### Mamba-2
|
||||
|
||||
The Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py).
|
||||
|
||||
A simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py)
|
||||
|
||||
The usage is similar to Mamba(-1):
|
||||
``` python
|
||||
from mamba_ssm import Mamba2
|
||||
model = Mamba2(
|
||||
# This module uses roughly 3 * expand * d_model^2 parameters
|
||||
d_model=dim, # Model dimension d_model
|
||||
d_state=64, # SSM state expansion factor, typically 64 or 128
|
||||
d_conv=4, # Local convolution width
|
||||
expand=2, # Block expansion factor
|
||||
).to("cuda")
|
||||
y = model(x)
|
||||
assert y.shape == x.shape
|
||||
```
|
||||
|
||||
#### SSD
|
||||
|
||||
A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions
|
||||
is at [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py).
|
||||
|
||||
### Mamba Language Model
|
||||
|
||||
Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
|
||||
|
||||
Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
|
||||
|
||||
This is an example of how to integrate Mamba into an end-to-end neural network.
|
||||
This example is used in the generation scripts below.
|
||||
|
||||
|
||||
## Pretrained Models
|
||||
|
||||
Pretrained models are uploaded to
|
||||
[Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
|
||||
`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`,
|
||||
`mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
|
||||
(trained on 600B tokens on the SlimPajama dataset).
|
||||
|
||||
|
||||
The models will be autodownloaded by the generation script below.
|
||||
|
||||
These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
|
||||
|
||||
| Parameters | Layers | Model dim. |
|
||||
|------------|--------|------------|
|
||||
| 130M | 24 | 768 |
|
||||
| 370M | 48 | 1024 |
|
||||
| 790M | 48 | 1536 |
|
||||
| 1.4B | 48 | 2048 |
|
||||
| 2.8B | 64 | 2560 |
|
||||
|
||||
(The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
|
||||
|
||||
Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
|
||||
Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
|
||||
|
||||
|
||||
## Evaluations
|
||||
|
||||
To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
|
||||
we use the
|
||||
[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
|
||||
library.
|
||||
|
||||
1. Install `lm-evaluation-harness` by `pip install lm-eval==0.4.2`.
|
||||
2. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
|
||||
``` sh
|
||||
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
|
||||
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
|
||||
```
|
||||
|
||||
To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts:
|
||||
``` sh
|
||||
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
|
||||
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
|
||||
```
|
||||
|
||||
To run evaluations on Mamba-2 models, simply replace the model names:
|
||||
``` sh
|
||||
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
|
||||
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
|
||||
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
|
||||
```
|
||||
|
||||
Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
|
||||
|
||||
## Inference
|
||||
|
||||
The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
|
||||
1. autoloads a model from the Hugging Face Hub,
|
||||
2. generates completions of a user-specified prompt,
|
||||
3. benchmarks the inference speed of this generation.
|
||||
|
||||
Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
|
||||
|
||||
### Examples
|
||||
|
||||
To test generation latency (e.g. batch size = 1) with different sampling strategies:
|
||||
|
||||
``` sh
|
||||
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
|
||||
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
|
||||
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
|
||||
```
|
||||
|
||||
To test generation throughput with random prompts (e.g. large batch size):
|
||||
``` sh
|
||||
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
|
||||
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64
|
||||
```
|
||||
|
||||
With Mamba-2, you just need to change the model name:
|
||||
``` sh
|
||||
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
|
||||
```
|
||||
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Precision
|
||||
Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.
|
||||
On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcasts when necessary (e.g. for optimizer accumulation).
|
||||
|
||||
We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities,
|
||||
as a first step please try a framework storing parameters in fp32 (such as AMP).
|
||||
|
||||
### Initialization
|
||||
Some parts of the model have initializations inherited from prior work on S4 models.
|
||||
For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter has a targeted range by initializing the bias of its linear projection.
|
||||
However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero).
|
||||
If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework)
|
||||
that is specific to the training framework.
|
||||
|
||||
## Additional Prerequisites for AMD cards
|
||||
|
||||
### Patching ROCm
|
||||
|
||||
If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
|
||||
|
||||
1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
|
||||
|
||||
2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
|
||||
```bash
|
||||
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
|
||||
```
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this codebase, or otherwise find our work valuable, please cite Mamba:
|
||||
```
|
||||
@article{mamba,
|
||||
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
|
||||
author={Gu, Albert and Dao, Tri},
|
||||
journal={arXiv preprint arXiv:2312.00752},
|
||||
year={2023}
|
||||
}
|
||||
|
||||
@inproceedings{mamba2,
|
||||
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
|
||||
author={Dao, Tri and Gu, Albert},
|
||||
booktitle={International Conference on Machine Learning (ICML)},
|
||||
year={2024}
|
||||
}
|
||||
|
||||
```
|
||||
BIN
Mamba/mamba-main/assets/selection.png
Normal file
BIN
Mamba/mamba-main/assets/selection.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 799 KiB |
BIN
Mamba/mamba-main/assets/ssd_algorithm.png
Normal file
BIN
Mamba/mamba-main/assets/ssd_algorithm.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.1 MiB |
@ -0,0 +1,92 @@
|
||||
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generation benchmarking")
|
||||
parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
|
||||
parser.add_argument("--prompt", type=str, default=None)
|
||||
parser.add_argument("--promptlen", type=int, default=100)
|
||||
parser.add_argument("--genlen", type=int, default=100)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--topk", type=int, default=1)
|
||||
parser.add_argument("--topp", type=float, default=1.0)
|
||||
parser.add_argument("--minp", type=float, default=0.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--batch", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
repeats = 3
|
||||
device = "cuda"
|
||||
dtype = torch.float16
|
||||
|
||||
print(f"Loading model {args.model_name}")
|
||||
is_mamba = args.model_name.startswith("state-spaces/mamba") or args.model_name.startswith("state-spaces/transformerpp")
|
||||
if is_mamba:
|
||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
||||
model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
|
||||
model.eval()
|
||||
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
|
||||
torch.random.manual_seed(0)
|
||||
if args.prompt is None:
|
||||
input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
|
||||
attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
|
||||
else:
|
||||
tokens = tokenizer(args.prompt, return_tensors="pt")
|
||||
input_ids = tokens.input_ids.to(device=device)
|
||||
attn_mask = tokens.attention_mask.to(device=device)
|
||||
max_length = input_ids.shape[1] + args.genlen
|
||||
|
||||
if is_mamba:
|
||||
fn = lambda: model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=max_length,
|
||||
cg=True,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
enable_timing=False,
|
||||
temperature=args.temperature,
|
||||
top_k=args.topk,
|
||||
top_p=args.topp,
|
||||
min_p=args.minp,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
)
|
||||
else:
|
||||
fn = lambda: model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attn_mask,
|
||||
max_length=max_length,
|
||||
return_dict_in_generate=True,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
do_sample=True,
|
||||
temperature=args.temperature,
|
||||
top_k=args.topk,
|
||||
top_p=args.topp,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
)
|
||||
out = fn()
|
||||
if args.prompt is not None:
|
||||
print(tokenizer.batch_decode(out.sequences.tolist()))
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for _ in range(repeats):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
|
||||
print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
|
||||
415
Mamba/mamba-main/csrc/selective_scan/reverse_scan.cuh
Normal file
415
Mamba/mamba-main/csrc/selective_scan/reverse_scan.cuh
Normal file
@ -0,0 +1,415 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/config.cuh>
|
||||
|
||||
#include <cub/util_ptx.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#include <cub/block/block_raking_layout.cuh>
|
||||
// #include <cub/detail/uninitialized_copy.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
#include "uninitialized_copy.cuh"
|
||||
|
||||
/**
|
||||
* Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
|
||||
*/
|
||||
template <
|
||||
int LENGTH,
|
||||
typename T,
|
||||
typename ReductionOp>
|
||||
__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
|
||||
static_assert(LENGTH > 0);
|
||||
T retval = input[LENGTH - 1];
|
||||
#pragma unroll
|
||||
for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
|
||||
return retval;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
|
||||
*/
|
||||
template <
|
||||
int LENGTH,
|
||||
typename T,
|
||||
typename ScanOp>
|
||||
__device__ __forceinline__ T ThreadReverseScanInclusive(
|
||||
const T (&input)[LENGTH],
|
||||
T (&output)[LENGTH],
|
||||
ScanOp scan_op,
|
||||
const T postfix)
|
||||
{
|
||||
T inclusive = postfix;
|
||||
#pragma unroll
|
||||
for (int i = LENGTH - 1; i >= 0; --i) {
|
||||
inclusive = scan_op(inclusive, input[i]);
|
||||
output[i] = inclusive;
|
||||
}
|
||||
return inclusive;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
|
||||
*/
|
||||
template <
|
||||
int LENGTH,
|
||||
typename T,
|
||||
typename ScanOp>
|
||||
__device__ __forceinline__ T ThreadReverseScanExclusive(
|
||||
const T (&input)[LENGTH],
|
||||
T (&output)[LENGTH],
|
||||
ScanOp scan_op,
|
||||
const T postfix)
|
||||
{
|
||||
// Careful, output maybe be aliased to input
|
||||
T exclusive = postfix;
|
||||
T inclusive;
|
||||
#pragma unroll
|
||||
for (int i = LENGTH - 1; i >= 0; --i) {
|
||||
inclusive = scan_op(exclusive, input[i]);
|
||||
output[i] = exclusive;
|
||||
exclusive = inclusive;
|
||||
}
|
||||
return inclusive;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
|
||||
*
|
||||
* LOGICAL_WARP_THREADS must be a power-of-two
|
||||
*/
|
||||
template <
|
||||
typename T, ///< Data type being scanned
|
||||
int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
|
||||
>
|
||||
struct WarpReverseScan {
|
||||
//---------------------------------------------------------------------
|
||||
// Constants and type definitions
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
/// Whether the logical warp size and the PTX warp size coincide
|
||||
|
||||
// In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size()
|
||||
// While in cub, it's defined as a macro that takes a redundant unused argument.
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_THREADS CUB_WARP_THREADS(0)
|
||||
#else
|
||||
#define WARP_THREADS HIPCUB_WARP_THREADS
|
||||
#endif
|
||||
static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);
|
||||
/// The number of warp scan steps
|
||||
static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
|
||||
static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
|
||||
|
||||
|
||||
//---------------------------------------------------------------------
|
||||
// Thread fields
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
/// Lane index in logical warp
|
||||
unsigned int lane_id;
|
||||
|
||||
/// Logical warp index in 32-thread physical warp
|
||||
unsigned int warp_id;
|
||||
|
||||
/// 32-thread physical warp member mask of logical warp
|
||||
unsigned int member_mask;
|
||||
|
||||
//---------------------------------------------------------------------
|
||||
// Construction
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
explicit __device__ __forceinline__
|
||||
WarpReverseScan()
|
||||
: lane_id(cub::LaneId())
|
||||
, warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
|
||||
, member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
|
||||
{
|
||||
if (!IS_ARCH_WARP) {
|
||||
lane_id = lane_id % LOGICAL_WARP_THREADS;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Broadcast
|
||||
__device__ __forceinline__ T Broadcast(
|
||||
T input, ///< [in] The value to broadcast
|
||||
int src_lane) ///< [in] Which warp lane is to do the broadcasting
|
||||
{
|
||||
return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
|
||||
}
|
||||
|
||||
|
||||
/// Inclusive scan
|
||||
template <typename ScanOpT>
|
||||
__device__ __forceinline__ void InclusiveReverseScan(
|
||||
T input, ///< [in] Calling thread's input item.
|
||||
T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
|
||||
ScanOpT scan_op) ///< [in] Binary scan operator
|
||||
{
|
||||
inclusive_output = input;
|
||||
#pragma unroll
|
||||
for (int STEP = 0; STEP < STEPS; STEP++) {
|
||||
int offset = 1 << STEP;
|
||||
T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
||||
inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
|
||||
);
|
||||
// Perform scan op if from a valid peer
|
||||
inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
|
||||
? inclusive_output : scan_op(temp, inclusive_output);
|
||||
}
|
||||
}
|
||||
|
||||
/// Exclusive scan
|
||||
// Get exclusive from inclusive
|
||||
template <typename ScanOpT>
|
||||
__device__ __forceinline__ void ExclusiveReverseScan(
|
||||
T input, ///< [in] Calling thread's input item.
|
||||
T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
|
||||
ScanOpT scan_op, ///< [in] Binary scan operator
|
||||
T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
|
||||
{
|
||||
T inclusive_output;
|
||||
InclusiveReverseScan(input, inclusive_output, scan_op);
|
||||
warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
|
||||
// initial value unknown
|
||||
exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
||||
inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
|
||||
*/
|
||||
template <typename ScanOpT>
|
||||
__device__ __forceinline__ void ReverseScan(
|
||||
T input, ///< [in] Calling thread's input item.
|
||||
T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
|
||||
T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
|
||||
ScanOpT scan_op) ///< [in] Binary scan operator
|
||||
{
|
||||
InclusiveReverseScan(input, inclusive_output, scan_op);
|
||||
// initial value unknown
|
||||
exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
||||
inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
|
||||
);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
|
||||
*/
|
||||
template <
|
||||
typename T, ///< Data type being scanned
|
||||
int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
|
||||
bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
|
||||
>
|
||||
struct BlockReverseScan {
|
||||
//---------------------------------------------------------------------
|
||||
// Types and constants
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
/// Constants
|
||||
/// The thread block size in threads
|
||||
static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
|
||||
|
||||
/// Layout type for padded thread block raking grid
|
||||
using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
|
||||
// The number of reduction elements is not a multiple of the number of raking threads for now
|
||||
static_assert(BlockRakingLayout::UNGUARDED);
|
||||
|
||||
/// Number of raking threads
|
||||
static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
|
||||
/// Number of raking elements per warp synchronous raking thread
|
||||
static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
|
||||
/// Cooperative work can be entirely warp synchronous
|
||||
static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
|
||||
|
||||
/// WarpReverseScan utility type
|
||||
using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
|
||||
|
||||
/// Shared memory storage layout type
|
||||
struct _TempStorage {
|
||||
typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
|
||||
};
|
||||
|
||||
|
||||
/// Alias wrapper allowing storage to be unioned
|
||||
struct TempStorage : cub::Uninitialized<_TempStorage> {};
|
||||
|
||||
|
||||
//---------------------------------------------------------------------
|
||||
// Per-thread fields
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
// Thread fields
|
||||
_TempStorage &temp_storage;
|
||||
unsigned int linear_tid;
|
||||
T cached_segment[SEGMENT_LENGTH];
|
||||
|
||||
|
||||
//---------------------------------------------------------------------
|
||||
// Utility methods
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
/// Performs upsweep raking reduction, returning the aggregate
|
||||
template <typename ScanOp>
|
||||
__device__ __forceinline__ T Upsweep(ScanOp scan_op) {
|
||||
T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
|
||||
// Read data into registers
|
||||
#pragma unroll
|
||||
for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
|
||||
T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
|
||||
#pragma unroll
|
||||
for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
|
||||
raking_partial = scan_op(raking_partial, cached_segment[i]);
|
||||
}
|
||||
return raking_partial;
|
||||
}
|
||||
|
||||
|
||||
/// Performs exclusive downsweep raking scan
|
||||
template <typename ScanOp>
|
||||
__device__ __forceinline__ void ExclusiveDownsweep(
|
||||
ScanOp scan_op,
|
||||
T raking_partial)
|
||||
{
|
||||
T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
|
||||
// Read data back into registers
|
||||
if (!MEMOIZE) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
|
||||
}
|
||||
ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
|
||||
// Write data back to smem
|
||||
#pragma unroll
|
||||
for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
|
||||
}
|
||||
|
||||
|
||||
//---------------------------------------------------------------------
|
||||
// Constructors
|
||||
//---------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
__device__ __forceinline__ BlockReverseScan(
|
||||
TempStorage &temp_storage)
|
||||
:
|
||||
temp_storage(temp_storage.Alias()),
|
||||
linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
|
||||
{}
|
||||
|
||||
|
||||
/// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
|
||||
template <
|
||||
typename ScanOp,
|
||||
typename BlockPostfixCallbackOp>
|
||||
__device__ __forceinline__ void ExclusiveReverseScan(
|
||||
T input, ///< [in] Calling thread's input item
|
||||
T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
|
||||
ScanOp scan_op, ///< [in] Binary scan operator
|
||||
BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
|
||||
{
|
||||
if (WARP_SYNCHRONOUS) {
|
||||
// Short-circuit directly to warp-synchronous scan
|
||||
T block_aggregate;
|
||||
WarpReverseScan warp_scan;
|
||||
warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
|
||||
// Obtain warp-wide postfix in lane0, then broadcast to other lanes
|
||||
T block_postfix = block_postfix_callback_op(block_aggregate);
|
||||
block_postfix = warp_scan.Broadcast(block_postfix, 0);
|
||||
exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
|
||||
} else {
|
||||
// Place thread partial into shared memory raking grid
|
||||
T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
|
||||
detail::uninitialized_copy(placement_ptr, input);
|
||||
cub::CTA_SYNC();
|
||||
// Reduce parallelism down to just raking threads
|
||||
if (linear_tid < RAKING_THREADS) {
|
||||
WarpReverseScan warp_scan;
|
||||
// Raking upsweep reduction across shared partials
|
||||
T upsweep_partial = Upsweep(scan_op);
|
||||
// Warp-synchronous scan
|
||||
T exclusive_partial, block_aggregate;
|
||||
warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
|
||||
// Obtain block-wide postfix in lane0, then broadcast to other lanes
|
||||
T block_postfix = block_postfix_callback_op(block_aggregate);
|
||||
block_postfix = warp_scan.Broadcast(block_postfix, 0);
|
||||
// Update postfix with warpscan exclusive partial
|
||||
T downsweep_postfix = linear_tid == RAKING_THREADS - 1
|
||||
? block_postfix : scan_op(block_postfix, exclusive_partial);
|
||||
// Exclusive raking downsweep scan
|
||||
ExclusiveDownsweep(scan_op, downsweep_postfix);
|
||||
}
|
||||
cub::CTA_SYNC();
|
||||
// Grab thread postfix from shared memory
|
||||
exclusive_output = *placement_ptr;
|
||||
|
||||
// // Compute warp scan in each warp.
|
||||
// // The exclusive output from the last lane in each warp is invalid.
|
||||
// T inclusive_output;
|
||||
// WarpReverseScan warp_scan;
|
||||
// warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
|
||||
|
||||
// // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
|
||||
// T block_aggregate;
|
||||
// T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
|
||||
|
||||
// // Apply warp postfix to our lane's partial
|
||||
// if (warp_id != 0) {
|
||||
// exclusive_output = scan_op(warp_postfix, exclusive_output);
|
||||
// if (lane_id == 0) { exclusive_output = warp_postfix; }
|
||||
// }
|
||||
|
||||
// // Use the first warp to determine the thread block postfix, returning the result in lane0
|
||||
// if (warp_id == 0) {
|
||||
// T block_postfix = block_postfix_callback_op(block_aggregate);
|
||||
// if (lane_id == 0) {
|
||||
// // Share the postfix with all threads
|
||||
// detail::uninitialized_copy(&temp_storage.block_postfix,
|
||||
// block_postfix);
|
||||
|
||||
// exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
|
||||
// }
|
||||
// }
|
||||
|
||||
// cub::CTA_SYNC();
|
||||
|
||||
// // Incorporate thread block postfix into outputs
|
||||
// T block_postfix = temp_storage.block_postfix;
|
||||
// if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
|
||||
*/
|
||||
template <
|
||||
int ITEMS_PER_THREAD,
|
||||
typename ScanOp,
|
||||
typename BlockPostfixCallbackOp>
|
||||
__device__ __forceinline__ void InclusiveReverseScan(
|
||||
T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
|
||||
T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
|
||||
ScanOp scan_op, ///< [in] Binary scan functor
|
||||
BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
|
||||
{
|
||||
// Reduce consecutive thread items in registers
|
||||
T thread_postfix = ThreadReverseReduce(input, scan_op);
|
||||
// Exclusive thread block-scan
|
||||
ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
|
||||
// Inclusive scan in registers with postfix as seed
|
||||
ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
|
||||
}
|
||||
|
||||
};
|
||||
497
Mamba/mamba-main/csrc/selective_scan/selective_scan.cpp
Normal file
497
Mamba/mamba-main/csrc/selective_scan/selective_scan.cpp
Normal file
@ -0,0 +1,497 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
#include "selective_scan.h"
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
|
||||
if (WTYPE == at::ScalarType::Half) { \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (WTYPE == at::ScalarType::BFloat16) { \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (WTYPE == at::ScalarType::Float) { \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
|
||||
if (WTYPE == at::ScalarType::Float) { \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (WTYPE == at::ScalarType::ComplexFloat) { \
|
||||
using weight_t = c10::complex<float>; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template <typename input_t, typename weight_t>
|
||||
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t dstate,
|
||||
const size_t n_groups,
|
||||
const size_t n_chunks,
|
||||
const bool is_variable_B,
|
||||
const bool is_variable_C,
|
||||
// device pointers
|
||||
const at::Tensor u,
|
||||
const at::Tensor delta,
|
||||
const at::Tensor A,
|
||||
const at::Tensor B,
|
||||
const at::Tensor C,
|
||||
const at::Tensor out,
|
||||
const at::Tensor z,
|
||||
const at::Tensor out_z,
|
||||
void* D_ptr,
|
||||
void* delta_bias_ptr,
|
||||
void* x_ptr,
|
||||
bool has_z,
|
||||
bool delta_softplus) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.dstate = dstate;
|
||||
params.n_groups = n_groups;
|
||||
params.n_chunks = n_chunks;
|
||||
params.dim_ngroups_ratio = dim / n_groups;
|
||||
|
||||
params.delta_softplus = delta_softplus;
|
||||
|
||||
params.is_variable_B = is_variable_B;
|
||||
params.is_variable_C = is_variable_C;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.u_ptr = u.data_ptr();
|
||||
params.delta_ptr = delta.data_ptr();
|
||||
params.A_ptr = A.data_ptr();
|
||||
params.B_ptr = B.data_ptr();
|
||||
params.C_ptr = C.data_ptr();
|
||||
params.D_ptr = D_ptr;
|
||||
params.delta_bias_ptr = delta_bias_ptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
params.x_ptr = x_ptr;
|
||||
params.z_ptr = has_z ? z.data_ptr() : nullptr;
|
||||
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
|
||||
// All stride are in elements, not bytes.
|
||||
params.A_d_stride = A.stride(0);
|
||||
params.A_dstate_stride = A.stride(1);
|
||||
if (!is_variable_B) {
|
||||
params.B_d_stride = B.stride(0);
|
||||
} else {
|
||||
params.B_batch_stride = B.stride(0);
|
||||
params.B_group_stride = B.stride(1);
|
||||
}
|
||||
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
|
||||
if (!is_variable_C) {
|
||||
params.C_d_stride = C.stride(0);
|
||||
} else {
|
||||
params.C_batch_stride = C.stride(0);
|
||||
params.C_group_stride = C.stride(1);
|
||||
}
|
||||
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
|
||||
params.u_batch_stride = u.stride(0);
|
||||
params.u_d_stride = u.stride(1);
|
||||
params.delta_batch_stride = delta.stride(0);
|
||||
params.delta_d_stride = delta.stride(1);
|
||||
if (has_z) {
|
||||
params.z_batch_stride = z.stride(0);
|
||||
params.z_d_stride = z.stride(1);
|
||||
params.out_z_batch_stride = out_z.stride(0);
|
||||
params.out_z_d_stride = out_z.stride(1);
|
||||
}
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_d_stride = out.stride(1);
|
||||
}
|
||||
|
||||
void set_ssm_params_bwd(SSMParamsBwd ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t dstate,
|
||||
const size_t n_groups,
|
||||
const size_t n_chunks,
|
||||
const bool is_variable_B,
|
||||
const bool is_variable_C,
|
||||
// device pointers
|
||||
const at::Tensor u,
|
||||
const at::Tensor delta,
|
||||
const at::Tensor A,
|
||||
const at::Tensor B,
|
||||
const at::Tensor C,
|
||||
const at::Tensor z,
|
||||
const at::Tensor out,
|
||||
const at::Tensor out_z,
|
||||
void* D_ptr,
|
||||
void* delta_bias_ptr,
|
||||
void* x_ptr,
|
||||
const at::Tensor dout,
|
||||
const at::Tensor du,
|
||||
const at::Tensor ddelta,
|
||||
const at::Tensor dA,
|
||||
const at::Tensor dB,
|
||||
const at::Tensor dC,
|
||||
const at::Tensor dz,
|
||||
void* dD_ptr,
|
||||
void* ddelta_bias_ptr,
|
||||
bool has_z,
|
||||
bool delta_softplus,
|
||||
bool recompute_out_z) {
|
||||
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
|
||||
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
||||
u, delta, A, B, C, has_z ? out : dout,
|
||||
has_z ? z : dout,
|
||||
// If not recompute_out_z, pass dout instead of out_z.
|
||||
// This won't be used by the bwd kernel
|
||||
recompute_out_z ? out_z : dout,
|
||||
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
|
||||
if (!recompute_out_z) { params.out_z_ptr = nullptr; }
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.dout_ptr = dout.data_ptr();
|
||||
params.du_ptr = du.data_ptr();
|
||||
params.dA_ptr = dA.data_ptr();
|
||||
params.dB_ptr = dB.data_ptr();
|
||||
params.dC_ptr = dC.data_ptr();
|
||||
params.dD_ptr = dD_ptr;
|
||||
params.ddelta_ptr = ddelta.data_ptr();
|
||||
params.ddelta_bias_ptr = ddelta_bias_ptr;
|
||||
params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
|
||||
// All stride are in elements, not bytes.
|
||||
params.dout_batch_stride = dout.stride(0);
|
||||
params.dout_d_stride = dout.stride(1);
|
||||
params.dA_d_stride = dA.stride(0);
|
||||
params.dA_dstate_stride = dA.stride(1);
|
||||
if (!is_variable_B) {
|
||||
params.dB_d_stride = dB.stride(0);
|
||||
} else {
|
||||
params.dB_batch_stride = dB.stride(0);
|
||||
params.dB_group_stride = dB.stride(1);
|
||||
}
|
||||
params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
|
||||
if (!is_variable_C) {
|
||||
params.dC_d_stride = dC.stride(0);
|
||||
} else {
|
||||
params.dC_batch_stride = dC.stride(0);
|
||||
params.dC_group_stride = dC.stride(1);
|
||||
}
|
||||
params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
|
||||
params.du_batch_stride = du.stride(0);
|
||||
params.du_d_stride = du.stride(1);
|
||||
params.ddelta_batch_stride = ddelta.stride(0);
|
||||
params.ddelta_d_stride = ddelta.stride(1);
|
||||
if (has_z) {
|
||||
params.dz_batch_stride = dz.stride(0);
|
||||
params.dz_d_stride = dz.stride(1);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
|
||||
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
||||
const c10::optional<at::Tensor> &D_,
|
||||
const c10::optional<at::Tensor> &z_,
|
||||
const c10::optional<at::Tensor> &delta_bias_,
|
||||
bool delta_softplus) {
|
||||
auto input_type = u.scalar_type();
|
||||
auto weight_type = A.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
|
||||
|
||||
const bool is_variable_B = B.dim() >= 3;
|
||||
const bool is_variable_C = C.dim() >= 3;
|
||||
const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
|
||||
|
||||
TORCH_CHECK(delta.scalar_type() == input_type);
|
||||
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
||||
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
||||
|
||||
TORCH_CHECK(u.is_cuda());
|
||||
TORCH_CHECK(delta.is_cuda());
|
||||
TORCH_CHECK(A.is_cuda());
|
||||
TORCH_CHECK(B.is_cuda());
|
||||
TORCH_CHECK(C.is_cuda());
|
||||
|
||||
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
||||
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
||||
|
||||
const auto sizes = u.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int dstate = A.size(1);
|
||||
const int n_groups = is_variable_B ? B.size(1) : 1;
|
||||
|
||||
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
||||
|
||||
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(A, dim, dstate);
|
||||
if (!is_variable_B) {
|
||||
CHECK_SHAPE(B, dim, dstate);
|
||||
} else {
|
||||
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
|
||||
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
||||
}
|
||||
if (!is_variable_C) {
|
||||
CHECK_SHAPE(C, dim, dstate);
|
||||
} else {
|
||||
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
|
||||
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
||||
}
|
||||
|
||||
if (D_.has_value()) {
|
||||
auto D = D_.value();
|
||||
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(D.is_cuda());
|
||||
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
||||
CHECK_SHAPE(D, dim);
|
||||
}
|
||||
|
||||
if (delta_bias_.has_value()) {
|
||||
auto delta_bias = delta_bias_.value();
|
||||
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(delta_bias.is_cuda());
|
||||
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
||||
CHECK_SHAPE(delta_bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
if (has_z) {
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
out_z = torch::empty_like(z);
|
||||
}
|
||||
|
||||
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
||||
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
||||
// at::Tensor out = torch::empty_like(u);
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = torch::empty_like(delta);
|
||||
at::Tensor x;
|
||||
x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
|
||||
|
||||
SSMParamsBase params;
|
||||
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
||||
u, delta, A, B, C, out, z, out_z,
|
||||
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
||||
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
||||
x.data_ptr(),
|
||||
has_z,
|
||||
delta_softplus);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
||||
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
});
|
||||
std::vector<at::Tensor> result = {out, x};
|
||||
if (has_z) { result.push_back(out_z); }
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
|
||||
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
||||
const c10::optional<at::Tensor> &D_,
|
||||
const c10::optional<at::Tensor> &z_,
|
||||
const c10::optional<at::Tensor> &delta_bias_,
|
||||
const at::Tensor &dout,
|
||||
const c10::optional<at::Tensor> &x_,
|
||||
const c10::optional<at::Tensor> &out_,
|
||||
c10::optional<at::Tensor> &dz_,
|
||||
bool delta_softplus,
|
||||
bool recompute_out_z) {
|
||||
auto input_type = u.scalar_type();
|
||||
auto weight_type = A.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
|
||||
|
||||
const bool is_variable_B = B.dim() >= 3;
|
||||
const bool is_variable_C = C.dim() >= 3;
|
||||
const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
|
||||
|
||||
TORCH_CHECK(delta.scalar_type() == input_type);
|
||||
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
||||
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
||||
TORCH_CHECK(dout.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(u.is_cuda());
|
||||
TORCH_CHECK(delta.is_cuda());
|
||||
TORCH_CHECK(A.is_cuda());
|
||||
TORCH_CHECK(B.is_cuda());
|
||||
TORCH_CHECK(C.is_cuda());
|
||||
TORCH_CHECK(dout.is_cuda());
|
||||
|
||||
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
||||
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
||||
TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
|
||||
|
||||
const auto sizes = u.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int dstate = A.size(1);
|
||||
const int n_groups = is_variable_B ? B.size(1) : 1;
|
||||
|
||||
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
||||
|
||||
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(A, dim, dstate);
|
||||
if (!is_variable_B) {
|
||||
CHECK_SHAPE(B, dim, dstate);
|
||||
} else {
|
||||
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
|
||||
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
||||
}
|
||||
if (!is_variable_C) {
|
||||
CHECK_SHAPE(C, dim, dstate);
|
||||
} else {
|
||||
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
|
||||
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
||||
}
|
||||
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
||||
|
||||
if (D_.has_value()) {
|
||||
auto D = D_.value();
|
||||
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(D.is_cuda());
|
||||
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
||||
CHECK_SHAPE(D, dim);
|
||||
}
|
||||
|
||||
if (delta_bias_.has_value()) {
|
||||
auto delta_bias = delta_bias_.value();
|
||||
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(delta_bias.is_cuda());
|
||||
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
||||
CHECK_SHAPE(delta_bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor z, out, dz, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
if (has_z) {
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
|
||||
TORCH_CHECK(out_.has_value());
|
||||
out = out_.value();
|
||||
TORCH_CHECK(out.scalar_type() == input_type);
|
||||
TORCH_CHECK(out.is_cuda());
|
||||
TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1);
|
||||
CHECK_SHAPE(out, batch_size, dim, seqlen);
|
||||
|
||||
if (dz_.has_value()) {
|
||||
dz = dz_.value();
|
||||
TORCH_CHECK(dz.scalar_type() == input_type);
|
||||
TORCH_CHECK(dz.is_cuda());
|
||||
TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1);
|
||||
CHECK_SHAPE(dz, batch_size, dim, seqlen);
|
||||
} else {
|
||||
dz = torch::empty_like(z);
|
||||
}
|
||||
if (recompute_out_z) {
|
||||
out_z = torch::empty_like(out);
|
||||
}
|
||||
}
|
||||
|
||||
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
||||
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
||||
if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
|
||||
if (x_.has_value()) {
|
||||
auto x = x_.value();
|
||||
TORCH_CHECK(x.scalar_type() == weight_type);
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(x.is_contiguous());
|
||||
CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
|
||||
}
|
||||
|
||||
at::Tensor du = torch::empty_like(u);
|
||||
at::Tensor ddelta = torch::empty_like(delta);
|
||||
at::Tensor dA = torch::zeros_like(A);
|
||||
at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
|
||||
at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
|
||||
at::Tensor dD;
|
||||
if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
|
||||
at::Tensor ddelta_bias;
|
||||
if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
|
||||
|
||||
SSMParamsBwd params;
|
||||
set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
||||
u, delta, A, B, C, z, out, out_z,
|
||||
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
||||
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
||||
x_.has_value() ? x_.value().data_ptr() : nullptr,
|
||||
dout, du, ddelta, dA, dB, dC, dz,
|
||||
D_.has_value() ? dD.data_ptr() : nullptr,
|
||||
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
|
||||
has_z, delta_softplus, recompute_out_z);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
|
||||
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
|
||||
selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
});
|
||||
std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
|
||||
if (has_z) { result.push_back(dz); }
|
||||
if (recompute_out_z) { result.push_back(out_z); }
|
||||
return result;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("fwd", &selective_scan_fwd, "Selective scan forward");
|
||||
m.def("bwd", &selective_scan_bwd, "Selective scan backward");
|
||||
}
|
||||
101
Mamba/mamba-main/csrc/selective_scan/selective_scan.h
Normal file
101
Mamba/mamba-main/csrc/selective_scan/selective_scan.h
Normal file
@ -0,0 +1,101 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SSMScanParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, seqlen, n_chunks;
|
||||
index_t a_batch_stride;
|
||||
index_t b_batch_stride;
|
||||
index_t out_batch_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ a_ptr;
|
||||
void *__restrict__ b_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
void *__restrict__ x_ptr;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SSMParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
||||
int dim_ngroups_ratio;
|
||||
bool is_variable_B;
|
||||
bool is_variable_C;
|
||||
|
||||
bool delta_softplus;
|
||||
|
||||
index_t A_d_stride;
|
||||
index_t A_dstate_stride;
|
||||
index_t B_batch_stride;
|
||||
index_t B_d_stride;
|
||||
index_t B_dstate_stride;
|
||||
index_t B_group_stride;
|
||||
index_t C_batch_stride;
|
||||
index_t C_d_stride;
|
||||
index_t C_dstate_stride;
|
||||
index_t C_group_stride;
|
||||
index_t u_batch_stride;
|
||||
index_t u_d_stride;
|
||||
index_t delta_batch_stride;
|
||||
index_t delta_d_stride;
|
||||
index_t z_batch_stride;
|
||||
index_t z_d_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_d_stride;
|
||||
index_t out_z_batch_stride;
|
||||
index_t out_z_d_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ A_ptr;
|
||||
void *__restrict__ B_ptr;
|
||||
void *__restrict__ C_ptr;
|
||||
void *__restrict__ D_ptr;
|
||||
void *__restrict__ u_ptr;
|
||||
void *__restrict__ delta_ptr;
|
||||
void *__restrict__ delta_bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ z_ptr;
|
||||
void *__restrict__ out_z_ptr;
|
||||
};
|
||||
|
||||
struct SSMParamsBwd: public SSMParamsBase {
|
||||
index_t dout_batch_stride;
|
||||
index_t dout_d_stride;
|
||||
index_t dA_d_stride;
|
||||
index_t dA_dstate_stride;
|
||||
index_t dB_batch_stride;
|
||||
index_t dB_group_stride;
|
||||
index_t dB_d_stride;
|
||||
index_t dB_dstate_stride;
|
||||
index_t dC_batch_stride;
|
||||
index_t dC_group_stride;
|
||||
index_t dC_d_stride;
|
||||
index_t dC_dstate_stride;
|
||||
index_t du_batch_stride;
|
||||
index_t du_d_stride;
|
||||
index_t dz_batch_stride;
|
||||
index_t dz_d_stride;
|
||||
index_t ddelta_batch_stride;
|
||||
index_t ddelta_d_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ dout_ptr;
|
||||
void *__restrict__ dA_ptr;
|
||||
void *__restrict__ dB_ptr;
|
||||
void *__restrict__ dC_ptr;
|
||||
void *__restrict__ dD_ptr;
|
||||
void *__restrict__ du_ptr;
|
||||
void *__restrict__ dz_ptr;
|
||||
void *__restrict__ ddelta_ptr;
|
||||
void *__restrict__ ddelta_bias_ptr;
|
||||
};
|
||||
@ -0,0 +1,9 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Split into multiple files to compile in paralell
|
||||
|
||||
#include "selective_scan_bwd_kernel.cuh"
|
||||
|
||||
template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,9 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Split into multiple files to compile in paralell
|
||||
|
||||
#include "selective_scan_bwd_kernel.cuh"
|
||||
|
||||
template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,9 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Split into multiple files to compile in paralell
|
||||
|
||||
#include "selective_scan_bwd_kernel.cuh"
|
||||
|
||||
template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,9 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Split into multiple files to compile in paralell
|
||||
|
||||
#include "selective_scan_bwd_kernel.cuh"
|
||||
|
||||
template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,9 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Split into multiple files to compile in paralell
|
||||
|
||||
#include "selective_scan_bwd_kernel.cuh"
|
||||
|
||||
template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,9 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Split into multiple files to compile in paralell
|
||||
|
||||
#include "selective_scan_bwd_kernel.cuh"
|
||||
|
||||
template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,567 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "selective_scan.h"
|
||||
#include "selective_scan_common.h"
|
||||
#include "reverse_scan.cuh"
|
||||
#include "static_switch.h"
|
||||
|
||||
template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x);
|
||||
template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
|
||||
template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }
|
||||
|
||||
template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
|
||||
bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
|
||||
struct Selective_Scan_bwd_kernel_traits {
|
||||
static_assert(kNItems_ % 4 == 0);
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kNItems = kNItems_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
|
||||
static_assert(kNItems % kNElts == 0);
|
||||
static constexpr int kNLoads = kNItems / kNElts;
|
||||
static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
|
||||
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
||||
static constexpr bool kIsVariableB = kIsVariableB_;
|
||||
static constexpr bool kIsVariableC = kIsVariableC_;
|
||||
static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
|
||||
static constexpr bool kHasZ = kHasZ_;
|
||||
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
|
||||
// For complex this would lead to massive register spilling, so we keep it at 2.
|
||||
static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
||||
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
|
||||
using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
|
||||
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
|
||||
using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
|
||||
using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
|
||||
|
||||
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
|
||||
sizeof(typename BlockLoadVecT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
||||
sizeof(typename BlockStoreT::TempStorage),
|
||||
sizeof(typename BlockStoreVecT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
|
||||
static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
||||
void selective_scan_bwd_kernel(SSMParamsBwd params) {
|
||||
constexpr bool kIsComplex = Ktraits::kIsComplex;
|
||||
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
||||
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
||||
constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
|
||||
constexpr bool kHasZ = Ktraits::kHasZ;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
using scan_t = typename Ktraits::scan_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
// cast to lvalue reference of expected type
|
||||
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
||||
auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
|
||||
auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
|
||||
auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
|
||||
auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce);
|
||||
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
|
||||
auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
|
||||
weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
|
||||
scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads);
|
||||
weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE);
|
||||
weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE);
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int dim_id = blockIdx.y;
|
||||
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
||||
+ dim_id * params.u_d_stride;
|
||||
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
||||
+ dim_id * params.delta_d_stride;
|
||||
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
||||
+ dim_id * params.dout_d_stride;
|
||||
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
|
||||
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride;
|
||||
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
||||
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride;
|
||||
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
||||
weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
|
||||
weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
|
||||
+ (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);
|
||||
weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
|
||||
+ (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);
|
||||
float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
|
||||
float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
|
||||
float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
|
||||
float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
|
||||
scan_t *x = params.x_ptr == nullptr
|
||||
? nullptr
|
||||
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
|
||||
float dD_val = 0;
|
||||
float ddelta_bias_val = 0;
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNItems;
|
||||
u += (params.n_chunks - 1) * kChunkSize;
|
||||
delta += (params.n_chunks - 1) * kChunkSize;
|
||||
dout += (params.n_chunks - 1) * kChunkSize;
|
||||
Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
|
||||
Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
|
||||
for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
|
||||
input_t u_vals[kNItems];
|
||||
input_t delta_vals_load[kNItems];
|
||||
input_t dout_vals_load[kNItems];
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
u -= kChunkSize;
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
// Will reload delta at the same location if kDeltaSoftplus
|
||||
if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
dout -= kChunkSize;
|
||||
|
||||
float dout_vals[kNItems], delta_vals[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
dout_vals[i] = float(dout_vals_load[i]);
|
||||
delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
|
||||
if constexpr (kDeltaSoftplus) {
|
||||
delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (kHasZ) {
|
||||
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
||||
+ dim_id * params.z_d_stride + chunk * kChunkSize;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ dim_id * params.out_d_stride + chunk * kChunkSize;
|
||||
input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride
|
||||
+ dim_id * params.dz_d_stride + chunk * kChunkSize;
|
||||
input_t z_vals[kNItems], out_vals[kNItems];
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
float dz_vals[kNItems], z_silu_vals[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float z_val = z_vals[i];
|
||||
float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));
|
||||
z_silu_vals[i] = z_val * z_sigmoid_val;
|
||||
dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val
|
||||
* (1.0f + z_val * (1.0f - z_sigmoid_val));
|
||||
dout_vals[i] *= z_silu_vals[i];
|
||||
}
|
||||
__syncthreads();
|
||||
store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
||||
if (params.out_z_ptr != nullptr) { // Recompute and store out_z
|
||||
float out_z_vals[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
|
||||
// printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
|
||||
// }
|
||||
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
||||
+ dim_id * params.out_z_d_stride + chunk * kChunkSize;
|
||||
__syncthreads();
|
||||
store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
}
|
||||
|
||||
float du_vals[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
|
||||
|
||||
float ddelta_vals[kNItems] = {0};
|
||||
__syncthreads();
|
||||
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
||||
const weight_t A_val = A[state_idx * params.A_dstate_stride];
|
||||
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
|
||||
weight_t A_scaled;
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
if constexpr (!kIsComplex) {
|
||||
A_scaled = A_val * kLog2e;
|
||||
} else {
|
||||
A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);
|
||||
}
|
||||
weight_t B_val, C_val;
|
||||
weight_t B_vals[kNItems], C_vals[kNItems];
|
||||
if constexpr (!kIsVariableB) {
|
||||
B_val = B[state_idx * params.B_dstate_stride];
|
||||
} else {
|
||||
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
||||
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
||||
}
|
||||
if constexpr (!kIsVariableC) {
|
||||
C_val = C[state_idx * params.C_dstate_stride];
|
||||
} else {
|
||||
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
||||
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
||||
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
||||
}
|
||||
// const weight_t A_val = smem_a[state_idx];
|
||||
scan_t thread_data[kNItems], thread_reverse_data[kNItems];
|
||||
if constexpr (!kIsComplex) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
|
||||
thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
||||
if (i == 0) {
|
||||
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
|
||||
} else {
|
||||
thread_reverse_data[i - 1].x = delta_a_exp;
|
||||
}
|
||||
thread_reverse_data[i].y = dout_vals[i] *
|
||||
(!kIsVariableC
|
||||
? (!kIsVariableB ? B_val * C_val : C_val)
|
||||
: (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
|
||||
}
|
||||
__syncthreads();
|
||||
thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
|
||||
? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
|
||||
: smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
|
||||
// Initialize running total
|
||||
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
|
||||
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
||||
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
||||
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
||||
);
|
||||
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
|
||||
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
|
||||
typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
|
||||
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
|
||||
);
|
||||
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
|
||||
weight_t dA_val = 0, dBC_val = 0;
|
||||
weight_t dB_vals[kNItems], dC_vals[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
const float dx = thread_reverse_data[i].y;
|
||||
const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];
|
||||
du_vals[i] += ddelta_u * delta_vals[i];
|
||||
const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
||||
ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
|
||||
dA_val += dx * delta_vals[i] * a;
|
||||
if constexpr (!kIsVariableB || !kIsVariableC) {
|
||||
if constexpr (!kIsVariableB) { // dBC_val is dB_val
|
||||
dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
|
||||
} else { // dBC_val is dC_val
|
||||
dBC_val += dout_vals[i] * thread_data[i].y;
|
||||
}
|
||||
}
|
||||
if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
|
||||
if constexpr (kIsVariableC) {
|
||||
dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
|
||||
}
|
||||
}
|
||||
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
|
||||
if constexpr (kIsVariableB || kIsVariableC) {
|
||||
if constexpr (kIsVariableB) {
|
||||
typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
|
||||
}
|
||||
if constexpr (kIsVariableC) {
|
||||
auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
|
||||
typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
|
||||
}
|
||||
const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
|
||||
weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
||||
weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
if (i * kNThreads < seqlen_remaining) {
|
||||
if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
|
||||
if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (!kIsVariableB || !kIsVariableC) {
|
||||
float2 dA_dBC_val = make_float2(dA_val, dBC_val);
|
||||
dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
|
||||
dA_val = dA_dBC_val.x;
|
||||
if (threadIdx.x == 0) {
|
||||
smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
|
||||
}
|
||||
} else {
|
||||
dA_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
// Pytorch's implementation of complex exp (which calls thrust) is very slow
|
||||
complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
|
||||
weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
|
||||
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
|
||||
if (i == 0) {
|
||||
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
|
||||
} else {
|
||||
thread_reverse_data[i - 1].x = delta_a_exp.real_;
|
||||
thread_reverse_data[i - 1].y = -delta_a_exp.imag_;
|
||||
}
|
||||
complex_t dout_BC = 2 * dout_vals[i]
|
||||
* conj(!kIsVariableC
|
||||
? (!kIsVariableB ? B_val * C_val : C_val)
|
||||
: (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
|
||||
thread_reverse_data[i].z = dout_BC.real_;
|
||||
thread_reverse_data[i].w = dout_BC.imag_;
|
||||
}
|
||||
__syncthreads();
|
||||
complex_t delta_a_exp = threadIdx.x == kNThreads - 1
|
||||
? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
|
||||
: smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
|
||||
thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
|
||||
thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
|
||||
// Initialize running total
|
||||
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
|
||||
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
||||
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
||||
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
||||
);
|
||||
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
|
||||
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
|
||||
typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
|
||||
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
|
||||
);
|
||||
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
|
||||
weight_t dA_val = 0, dBC_val = 0;
|
||||
weight_t dB_vals[kNItems], dC_vals[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
complex_t x = complex_t(thread_data[i].z, thread_data[i].w);
|
||||
complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
|
||||
float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
|
||||
if constexpr (!kIsVariableB || !kIsVariableC) {
|
||||
if constexpr (!kIsVariableB) { // dBC_val is dB_val
|
||||
dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
|
||||
} else { // dBC_val is dC_val
|
||||
dBC_val += (2 * dout_vals[i]) * conj(x);
|
||||
}
|
||||
}
|
||||
const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));
|
||||
du_vals[i] += ddelta_u * delta_vals[i];
|
||||
ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;
|
||||
dA_val += delta_vals[i] * dx * a_conj;
|
||||
if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
|
||||
if constexpr (kIsVariableC) {
|
||||
dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
|
||||
}
|
||||
}
|
||||
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
|
||||
if constexpr (kIsVariableB || kIsVariableC) {
|
||||
float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
|
||||
if constexpr (kIsVariableB) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
dB_vals_f[i * 2] = dB_vals[i].real_;
|
||||
dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
|
||||
}
|
||||
typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
|
||||
}
|
||||
if constexpr (kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
dC_vals_f[i * 2] = dC_vals[i].real_;
|
||||
dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
|
||||
}
|
||||
auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
|
||||
typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
|
||||
}
|
||||
const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
|
||||
float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
|
||||
float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems * 2; ++i) {
|
||||
if (i * kNThreads < seqlen_remaining) {
|
||||
if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }
|
||||
if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (!kIsVariableB || !kIsVariableC) {
|
||||
float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
|
||||
dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
|
||||
dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
|
||||
dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
|
||||
if (threadIdx.x == 0) {
|
||||
smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
|
||||
}
|
||||
} else {
|
||||
dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (kDeltaSoftplus) {
|
||||
__syncthreads();
|
||||
input_t delta_vals_load[kNItems];
|
||||
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
delta -= kChunkSize;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float delta_val = float(delta_vals_load[i]) + delta_bias;
|
||||
float delta_val_neg_exp = expf(-delta_val);
|
||||
ddelta_vals[i] = delta_val <= 20.f
|
||||
? ddelta_vals[i] / (1.f + delta_val_neg_exp)
|
||||
: ddelta_vals[i];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
|
||||
|
||||
input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
|
||||
+ dim_id * params.du_d_stride + chunk * kChunkSize;
|
||||
input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
|
||||
+ dim_id * params.ddelta_d_stride + chunk * kChunkSize;
|
||||
__syncthreads();
|
||||
store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
||||
__syncthreads();
|
||||
store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
||||
|
||||
Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);
|
||||
Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
|
||||
}
|
||||
if (params.dD_ptr != nullptr) {
|
||||
dD_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
|
||||
if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
|
||||
}
|
||||
if (params.ddelta_bias_ptr != nullptr) {
|
||||
__syncthreads();
|
||||
ddelta_bias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
|
||||
if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
|
||||
}
|
||||
for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
||||
gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
|
||||
weight_t dBC_val;
|
||||
if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }
|
||||
if constexpr (!kIsVariableB) {
|
||||
gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),
|
||||
!kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);
|
||||
}
|
||||
if constexpr (!kIsVariableC) {
|
||||
gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),
|
||||
!kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
||||
void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
|
||||
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
|
||||
BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
|
||||
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
||||
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
|
||||
// using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
|
||||
// TODO: check this
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
|
||||
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &selective_scan_bwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define warp_size 32
|
||||
#else
|
||||
#define warp_size ROCM_WARP_SIZE
|
||||
#endif
|
||||
|
||||
#if warp_size == 32
|
||||
if (params.seqlen <= 128) {
|
||||
selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 256) {
|
||||
selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#else
|
||||
if (params.seqlen <= 256) {
|
||||
selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
255
Mamba/mamba-main/csrc/selective_scan/selective_scan_common.h
Normal file
255
Mamba/mamba-main/csrc/selective_scan/selective_scan_common.h
Normal file
@ -0,0 +1,255 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <cuda_fp16.h>
|
||||
#include <c10/util/complex.h> // For scalar_value_type
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#define MAX_DSTATE 256
|
||||
|
||||
using complex_t = c10::complex<float>;
|
||||
|
||||
inline __device__ float2 operator+(const float2 & a, const float2 & b){
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
}
|
||||
|
||||
inline __device__ float3 operator+(const float3 &a, const float3 &b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
inline __device__ float4 operator+(const float4 & a, const float4 & b){
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename scalar_t, int N>
|
||||
struct Converter{
|
||||
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
|
||||
}
|
||||
};
|
||||
|
||||
template<int N>
|
||||
struct Converter<at::Half, N>{
|
||||
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
template<int N>
|
||||
struct Converter<at::BFloat16, N>{
|
||||
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
|
||||
// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
|
||||
__device__ __forceinline__ complex_t cexp2f(complex_t z) {
|
||||
float t = exp2f(z.real_);
|
||||
float c, s;
|
||||
sincosf(z.imag_, &s, &c);
|
||||
return complex_t(c * t, s * t);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ complex_t cexpf(complex_t z) {
|
||||
float t = expf(z.real_);
|
||||
float c, s;
|
||||
sincosf(z.imag_, &s, &c);
|
||||
return complex_t(c * t, s * t);
|
||||
}
|
||||
|
||||
template<typename scalar_t> struct SSMScanOp;
|
||||
|
||||
template<>
|
||||
struct SSMScanOp<float> {
|
||||
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
|
||||
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct SSMScanOp<complex_t> {
|
||||
__device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
|
||||
complex_t a0 = complex_t(ab0.x, ab0.y);
|
||||
complex_t b0 = complex_t(ab0.z, ab0.w);
|
||||
complex_t a1 = complex_t(ab1.x, ab1.y);
|
||||
complex_t b1 = complex_t(ab1.z, ab1.w);
|
||||
complex_t out_a = a1 * a0;
|
||||
complex_t out_b = a1 * b0 + b1;
|
||||
return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
|
||||
}
|
||||
};
|
||||
|
||||
// A stateful callback functor that maintains a running prefix to be applied
|
||||
// during consecutive scan operations.
|
||||
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
|
||||
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
|
||||
scan_t running_prefix;
|
||||
// Constructor
|
||||
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
|
||||
// Callback operator to be entered by the first warp of threads in the block.
|
||||
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
|
||||
__device__ scan_t operator()(scan_t block_aggregate) {
|
||||
scan_t old_prefix = running_prefix;
|
||||
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
|
||||
return old_prefix;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_input(typename Ktraits::input_t *u,
|
||||
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadT::TempStorage &smem_load,
|
||||
int seqlen) {
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
|
||||
reinterpret_cast<vec_t*>(u),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
|
||||
#ifdef USE_ROCM
|
||||
, Ktraits::kNThreads * Ktraits::kNLoads
|
||||
#endif
|
||||
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
|
||||
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
|
||||
int seqlen) {
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
if constexpr (!Ktraits::kIsComplex) {
|
||||
typename Ktraits::input_t B_vals_load[kNItems];
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
||||
reinterpret_cast<vec_t*>(Bvar),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
||||
}
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
|
||||
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
|
||||
} else {
|
||||
typename Ktraits::input_t B_vals_load[kNItems * 2];
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
||||
reinterpret_cast<vec_t*>(Bvar),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void store_output(typename Ktraits::input_t *out,
|
||||
const float (&out_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockStoreT::TempStorage &smem_store,
|
||||
int seqlen) {
|
||||
typename Ktraits::input_t write_vals[Ktraits::kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
|
||||
reinterpret_cast<vec_t*>(out),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,10 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Split into multiple files to compile in paralell
|
||||
|
||||
#include "selective_scan_fwd_kernel.cuh"
|
||||
|
||||
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,10 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Split into multiple files to compile in paralell
|
||||
|
||||
#include "selective_scan_fwd_kernel.cuh"
|
||||
|
||||
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,10 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
// Split into multiple files to compile in paralell
|
||||
|
||||
#include "selective_scan_fwd_kernel.cuh"
|
||||
|
||||
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
@ -0,0 +1,382 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "selective_scan.h"
|
||||
#include "selective_scan_common.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
|
||||
bool kIsVariableB_, bool kIsVariableC_,
|
||||
bool kHasZ_, typename input_t_, typename weight_t_>
|
||||
struct Selective_Scan_fwd_kernel_traits {
|
||||
static_assert(kNItems_ % 4 == 0);
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
||||
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
||||
static constexpr int kNItems = kNItems_;
|
||||
static constexpr int kNRows = kNRows_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
|
||||
static_assert(kNItems % kNElts == 0);
|
||||
static constexpr int kNLoads = kNItems / kNElts;
|
||||
static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
|
||||
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
||||
static constexpr bool kIsVariableB = kIsVariableB_;
|
||||
static constexpr bool kIsVariableC = kIsVariableC_;
|
||||
static constexpr bool kHasZ = kHasZ_;
|
||||
|
||||
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
|
||||
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
||||
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
|
||||
sizeof(typename BlockLoadVecT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
||||
sizeof(typename BlockStoreT::TempStorage),
|
||||
sizeof(typename BlockStoreVecT::TempStorage)});
|
||||
static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
||||
void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
constexpr bool kIsComplex = Ktraits::kIsComplex;
|
||||
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
||||
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
||||
constexpr bool kHasZ = Ktraits::kHasZ;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
constexpr int kNRows = Ktraits::kNRows;
|
||||
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
using scan_t = typename Ktraits::scan_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
// cast to lvalue reference of expected type
|
||||
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
||||
// weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
|
||||
// weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
|
||||
scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int dim_id = blockIdx.y;
|
||||
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
||||
+ dim_id * kNRows * params.u_d_stride;
|
||||
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
||||
+ dim_id * kNRows * params.delta_d_stride;
|
||||
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
|
||||
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
|
||||
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
||||
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
||||
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
||||
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
|
||||
|
||||
float D_val[kNRows] = {0};
|
||||
if (params.D_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
|
||||
}
|
||||
}
|
||||
float delta_bias[kNRows] = {0};
|
||||
if (params.delta_bias_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
|
||||
}
|
||||
}
|
||||
|
||||
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
||||
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
||||
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
||||
// }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNItems;
|
||||
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
|
||||
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||
if constexpr (!kDirectIO) { __syncthreads(); }
|
||||
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
u += kChunkSize;
|
||||
delta += kChunkSize;
|
||||
|
||||
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float u_val = float(u_vals[r][i]);
|
||||
delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
|
||||
if (params.delta_softplus) {
|
||||
delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
|
||||
}
|
||||
delta_u_vals[r][i] = delta_vals[r][i] * u_val;
|
||||
out_vals[r][i] = D_val[r] * u_val;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
||||
weight_t A_val[kNRows];
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
|
||||
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
if constexpr (!kIsComplex) {
|
||||
A_val[r] *= kLog2e;
|
||||
} else {
|
||||
A_val[r].real_ *= kLog2e;
|
||||
}
|
||||
}
|
||||
// This variable holds B * C if both B and C are constant across seqlen. If only B varies
|
||||
// across seqlen, this holds C. If only C varies across seqlen, this holds B.
|
||||
// If both B and C vary, this is unused.
|
||||
weight_t BC_val[kNRows];
|
||||
weight_t B_vals[kNItems], C_vals[kNItems];
|
||||
if constexpr (kIsVariableB) {
|
||||
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
||||
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
||||
if constexpr (!kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (kIsVariableC) {
|
||||
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
||||
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
||||
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
||||
if constexpr (!kIsVariableB) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (!kIsVariableB && !kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if (r > 0) { __syncthreads(); } // Scan could be using the same smem
|
||||
scan_t thread_data[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
if constexpr (!kIsComplex) {
|
||||
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
||||
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
||||
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
||||
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
||||
thread_data[i] = make_float2(1.f, 0.f);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Pytorch's implementation of complex exp (which calls thrust) is very slow
|
||||
complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
|
||||
weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
|
||||
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
|
||||
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
||||
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
||||
thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Initialize running total
|
||||
scan_t running_prefix;
|
||||
if constexpr (!kIsComplex) {
|
||||
// If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
|
||||
running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
|
||||
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
||||
} else {
|
||||
running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
|
||||
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
||||
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
||||
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
||||
);
|
||||
// There's a syncthreads in the scan op, so we don't need to sync here.
|
||||
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
|
||||
if (threadIdx.x == 0) {
|
||||
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
||||
x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
const weight_t C_val = !kIsVariableC
|
||||
? BC_val[r]
|
||||
: (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
|
||||
if constexpr (!kIsComplex) {
|
||||
out_vals[r][i] += thread_data[i].y * C_val;
|
||||
} else {
|
||||
out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
|
||||
if constexpr (kHasZ) {
|
||||
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
||||
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
||||
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
||||
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
input_t z_vals[kNItems];
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float z_val = z_vals[i];
|
||||
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
||||
}
|
||||
__syncthreads();
|
||||
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
}
|
||||
|
||||
Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
|
||||
Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
|
||||
// processing 1 row.
|
||||
constexpr int kNRows = 1;
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
|
||||
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
|
||||
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
|
||||
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
|
||||
// Had to change this substantially since potentially the hip
|
||||
// interface for setting kernel launch attributes is slightly different from
|
||||
// cuda's. In particualar, it seems to expect a plain const void * pointer.
|
||||
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
#ifndef USE_ROCM
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define warp_size 32
|
||||
#else
|
||||
#define warp_size ROCM_WARP_SIZE
|
||||
#endif
|
||||
|
||||
#if warp_size == 32
|
||||
if (params.seqlen <= 128) {
|
||||
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 256) {
|
||||
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#else
|
||||
if (params.seqlen <= 256) {
|
||||
selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
25
Mamba/mamba-main/csrc/selective_scan/static_switch.h
Normal file
25
Mamba/mamba-main/csrc/selective_scan/static_switch.h
Normal file
@ -0,0 +1,25 @@
|
||||
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
77
Mamba/mamba-main/csrc/selective_scan/uninitialized_copy.cuh
Normal file
77
Mamba/mamba-main/csrc/selective_scan/uninitialized_copy.cuh
Normal file
@ -0,0 +1,77 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/config.cuh>
|
||||
|
||||
#include <cuda/std/type_traits>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
// Map ::cuda::std to the standard std namespace
|
||||
namespace cuda {
|
||||
namespace std = ::std;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
#if defined(_NVHPC_CUDA)
|
||||
template <typename T, typename U>
|
||||
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
|
||||
{
|
||||
// NVBug 3384810
|
||||
new (ptr) T(::cuda::std::forward<U>(val));
|
||||
}
|
||||
#else
|
||||
template <typename T,
|
||||
typename U,
|
||||
typename ::cuda::std::enable_if<
|
||||
::cuda::std::is_trivially_copyable<T>::value,
|
||||
int
|
||||
>::type = 0>
|
||||
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
|
||||
{
|
||||
*ptr = ::cuda::std::forward<U>(val);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename U,
|
||||
typename ::cuda::std::enable_if<
|
||||
!::cuda::std::is_trivially_copyable<T>::value,
|
||||
int
|
||||
>::type = 0>
|
||||
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
|
||||
{
|
||||
new (ptr) T(::cuda::std::forward<U>(val));
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
762478
Mamba/mamba-main/data/dataset/test_data.csv
Normal file
762478
Mamba/mamba-main/data/dataset/test_data.csv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
Mamba/mamba-main/data/dataset/train_data.csv
Normal file
BIN
Mamba/mamba-main/data/dataset/train_data.csv
Normal file
Binary file not shown.
|
Can't render this file because it is too large.
|
39
Mamba/mamba-main/evals/lm_harness_eval.py
Normal file
39
Mamba/mamba-main/evals/lm_harness_eval.py
Normal file
@ -0,0 +1,39 @@
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
||||
|
||||
from lm_eval.api.model import LM
|
||||
from lm_eval.models.huggingface import HFLM
|
||||
from lm_eval.api.registry import register_model
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
|
||||
|
||||
@register_model("mamba")
|
||||
class MambaEvalWrapper(HFLM):
|
||||
|
||||
AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
|
||||
|
||||
def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
|
||||
dtype=torch.float16):
|
||||
LM.__init__(self)
|
||||
self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
||||
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
||||
self.vocab_size = self.tokenizer.vocab_size
|
||||
self._batch_size = int(batch_size) if batch_size is not None else 64
|
||||
self._max_length = max_length
|
||||
self._device = torch.device(device)
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def _model_generate(self, context, max_length, stop, **generation_kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_evaluate()
|
||||
6
Mamba/mamba-main/mamba_ssm/__init__.py
Normal file
6
Mamba/mamba-main/mamba_ssm/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
__version__ = "2.1.0"
|
||||
|
||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
||||
from mamba_ssm.modules.mamba_simple import Mamba
|
||||
from mamba_ssm.modules.mamba2 import Mamba2
|
||||
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
||||
0
Mamba/mamba-main/mamba_ssm/distributed/__init__.py
Normal file
0
Mamba/mamba-main/mamba_ssm/distributed/__init__.py
Normal file
144
Mamba/mamba-main/mamba_ssm/distributed/distributed_utils.py
Normal file
144
Mamba/mamba-main/mamba_ssm/distributed/distributed_utils.py
Normal file
@ -0,0 +1,144 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
||||
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
||||
# version of PyTorch. The following 4 lines are for backward compatibility with
|
||||
# older PyTorch.
|
||||
if "all_gather_into_tensor" not in dir(torch.distributed):
|
||||
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
||||
if "reduce_scatter_tensor" not in dir(torch.distributed):
|
||||
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
||||
|
||||
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
output = torch.empty(
|
||||
world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
handle = torch.distributed.all_gather_into_tensor(
|
||||
output, input_.contiguous(), group=process_group, async_op=async_op
|
||||
)
|
||||
return output, handle
|
||||
|
||||
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
assert input_.shape[0] % world_size == 0
|
||||
output = torch.empty(
|
||||
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
handle = torch.distributed.reduce_scatter_tensor(
|
||||
output, input_.contiguous(), group=process_group, async_op=async_op
|
||||
)
|
||||
return output, handle
|
||||
|
||||
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
||||
input_ = input_.contiguous()
|
||||
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
|
||||
return input_, handle
|
||||
|
||||
|
||||
class AllGatherFunc(torch.autograd.Function):
|
||||
"""Gather the input from sequence parallel region and concatenate."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
||||
ctx.process_group = process_group
|
||||
output, _ = all_gather_raw(input_, process_group)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output: Tensor):
|
||||
grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
|
||||
return grad_input, None
|
||||
|
||||
|
||||
# Supports autograd, but does not support async
|
||||
all_gather = AllGatherFunc.apply
|
||||
|
||||
|
||||
class ReduceScatterFunc(torch.autograd.Function):
|
||||
"""Reduce scatter the input from the sequence parallel region and concatenate."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
||||
ctx.process_group = process_group
|
||||
output, _ = reduce_scatter_raw(input_, process_group)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output: Tensor):
|
||||
grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
|
||||
return grad_input, None
|
||||
|
||||
|
||||
# Supports autograd, but does not support async
|
||||
reduce_scatter = ReduceScatterFunc.apply
|
||||
|
||||
|
||||
class AllReduceFunc(torch.autograd.Function):
|
||||
"""Gather the input from sequence parallel region and concatenate."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
||||
ctx.process_group = process_group
|
||||
output, _ = all_reduce_raw(input_, process_group)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output: Tensor):
|
||||
return grad_output, None
|
||||
|
||||
|
||||
# Supports autograd, but does not support async
|
||||
all_reduce = AllReduceFunc.apply
|
||||
|
||||
|
||||
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
|
||||
# We want to iterate over parameters with _shared_params=True in the same order,
|
||||
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
|
||||
pamams_shared = {
|
||||
name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
|
||||
}
|
||||
for _, p in sorted(pamams_shared.items()):
|
||||
with torch.no_grad():
|
||||
# Broadcast needs src to be global rank, not group rank
|
||||
torch.distributed.broadcast(
|
||||
p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
|
||||
)
|
||||
|
||||
|
||||
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
|
||||
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
|
||||
# We want to iterate over parameters with _sequence_parallel=True in the same order,
|
||||
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
|
||||
params_seqparallel = {
|
||||
name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
|
||||
}
|
||||
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
|
||||
if grads:
|
||||
with torch.no_grad():
|
||||
coalesced = torch._utils._flatten_dense_tensors(grads)
|
||||
torch.distributed.all_reduce(coalesced, group=process_group)
|
||||
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
|
||||
buf.copy_(synced)
|
||||
|
||||
|
||||
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
|
||||
"""Get the dim for the local rank derived from splitting dim on world_size processes.
|
||||
|
||||
The split may not be even across the world_size processes.
|
||||
"""
|
||||
multiple = dim // multiple_of
|
||||
div = multiple // world_size
|
||||
mod = multiple % world_size
|
||||
local_multiple = div + int(local_rank < mod)
|
||||
return local_multiple * multiple_of
|
||||
296
Mamba/mamba-main/mamba_ssm/distributed/tensor_parallel.py
Normal file
296
Mamba/mamba-main/mamba_ssm/distributed/tensor_parallel.py
Normal file
@ -0,0 +1,296 @@
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from mamba_ssm.distributed.distributed_utils import (
|
||||
all_gather_raw,
|
||||
all_reduce,
|
||||
all_reduce_raw,
|
||||
reduce_scatter,
|
||||
reduce_scatter_raw,
|
||||
)
|
||||
|
||||
|
||||
class ParallelLinearFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
|
||||
"""
|
||||
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
|
||||
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
|
||||
"""
|
||||
ctx.compute_weight_gradient = weight.requires_grad
|
||||
ctx.process_group = process_group
|
||||
ctx.sequence_parallel = sequence_parallel
|
||||
|
||||
if torch.is_autocast_enabled():
|
||||
x = x.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
x = x.contiguous()
|
||||
if process_group is not None and sequence_parallel:
|
||||
# We want to kick off the all_gather early, before weight dtype conversion
|
||||
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
||||
else:
|
||||
total_x = x
|
||||
|
||||
if torch.is_autocast_enabled():
|
||||
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
|
||||
weight = weight.contiguous()
|
||||
if process_group is not None and sequence_parallel:
|
||||
handle_x.wait()
|
||||
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
|
||||
output = F.linear(total_x, weight, bias)
|
||||
if ctx.compute_weight_gradient:
|
||||
ctx.save_for_backward(x, weight)
|
||||
else:
|
||||
ctx.save_for_backward(weight)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
grad_output = grad_output.contiguous()
|
||||
process_group = ctx.process_group
|
||||
sequence_parallel = ctx.sequence_parallel
|
||||
if ctx.compute_weight_gradient:
|
||||
x, weight = ctx.saved_tensors
|
||||
if process_group is not None and sequence_parallel:
|
||||
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
||||
else:
|
||||
total_x = x
|
||||
else:
|
||||
(weight,) = ctx.saved_tensors
|
||||
total_x = None
|
||||
batch_shape = grad_output.shape[:-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = F.linear(grad_output, weight.t())
|
||||
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
|
||||
if process_group is not None:
|
||||
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
|
||||
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
|
||||
else:
|
||||
grad_input = None
|
||||
if ctx.needs_input_grad[1]:
|
||||
assert ctx.compute_weight_gradient
|
||||
if process_group is not None and sequence_parallel:
|
||||
handle_x.wait()
|
||||
grad_weight = torch.einsum(
|
||||
"bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
|
||||
)
|
||||
else:
|
||||
grad_weight = None
|
||||
grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
|
||||
if process_group is not None and ctx.needs_input_grad[0]:
|
||||
handle_grad_input.wait()
|
||||
return grad_input, grad_weight, grad_bias, None, None
|
||||
|
||||
|
||||
def parallel_linear_func(
|
||||
x: Tensor,
|
||||
weight: Tensor,
|
||||
bias: Optional[Tensor] = None,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
sequence_parallel: bool = True,
|
||||
):
|
||||
return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
|
||||
|
||||
|
||||
class ColumnParallelLinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
process_group: ProcessGroup,
|
||||
bias: bool = True,
|
||||
sequence_parallel=True,
|
||||
multiple_of=1,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if out_features % multiple_of:
|
||||
raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
|
||||
multiple = out_features // multiple_of
|
||||
# We want to split @multiple across world_size, but it could be an uneven split
|
||||
div = multiple // world_size
|
||||
mod = multiple % world_size
|
||||
# The first @mod ranks get @div + 1 copies, the rest get @div copies
|
||||
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
||||
super().__init__(
|
||||
in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
|
||||
)
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
|
||||
def forward(self, x):
|
||||
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
|
||||
# we do an all_gather of x before doing the matmul.
|
||||
# If not, then the input is already gathered.
|
||||
return parallel_linear_func(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
process_group=self.process_group,
|
||||
sequence_parallel=self.sequence_parallel,
|
||||
)
|
||||
|
||||
|
||||
class RowParallelLinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
process_group: ProcessGroup,
|
||||
bias: bool = True,
|
||||
sequence_parallel=True,
|
||||
multiple_of=1,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
rank = torch.distributed.get_rank(process_group)
|
||||
if in_features % multiple_of:
|
||||
raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
|
||||
multiple = in_features // multiple_of
|
||||
# We want to split @multiple across world_size, but it could be an uneven split
|
||||
div = multiple // world_size
|
||||
mod = multiple % world_size
|
||||
# The first @mod ranks get @div + 1 copies, the rest get @div copies
|
||||
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
||||
# Only rank 0 will have bias
|
||||
super().__init__(
|
||||
local_multiple * multiple_of,
|
||||
out_features,
|
||||
bias=bias and rank == 0,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
|
||||
a reduce_scatter of the result.
|
||||
"""
|
||||
out = parallel_linear_func(x, self.weight, self.bias)
|
||||
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
||||
return reduce_fn(out, self.process_group)
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Embedding):
|
||||
def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
|
||||
self.process_group = process_group
|
||||
if process_group is not None:
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if num_embeddings % world_size != 0:
|
||||
raise ValueError(
|
||||
f"num_embeddings ({num_embeddings}) must be divisible by "
|
||||
f"world_size ({world_size})"
|
||||
)
|
||||
if world_size > 1 and padding_idx is not None:
|
||||
raise RuntimeError("ParallelEmbedding does not support padding_idx")
|
||||
else:
|
||||
world_size = 1
|
||||
super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
if self.process_group is None:
|
||||
return super().forward(input)
|
||||
else:
|
||||
rank = torch.distributed.get_rank(self.process_group)
|
||||
vocab_size = self.num_embeddings
|
||||
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
|
||||
input = input - vocab_start_index
|
||||
input[input_ids_mask] = 0
|
||||
embeddings = super().forward(input)
|
||||
embeddings[input_ids_mask] = 0.0
|
||||
return embeddings
|
||||
|
||||
|
||||
class ColumnParallelEmbedding(nn.Embedding):
|
||||
def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
|
||||
self.process_group = process_group
|
||||
if process_group is not None:
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if embedding_dim % world_size != 0:
|
||||
raise ValueError(
|
||||
f"embedding_dim ({embedding_dim}) must be divisible by "
|
||||
f"world_size ({world_size})"
|
||||
)
|
||||
else:
|
||||
world_size = 1
|
||||
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
|
||||
|
||||
|
||||
class ParallelEmbeddings(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
vocab_size,
|
||||
max_position_embeddings,
|
||||
process_group,
|
||||
padding_idx=None,
|
||||
sequence_parallel=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
vocab_size,
|
||||
embed_dim,
|
||||
padding_idx=padding_idx,
|
||||
process_group=process_group,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
if self.max_position_embeddings > 0:
|
||||
self.position_embeddings = ColumnParallelEmbedding(
|
||||
max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
|
||||
)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
|
||||
"""
|
||||
input_ids: (batch, seqlen)
|
||||
position_ids: (batch, seqlen)
|
||||
"""
|
||||
batch_size, seqlen = input_ids.shape
|
||||
world_size = torch.distributed.get_world_size(self.process_group)
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.max_position_embeddings > 0:
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
if world_size <= 1:
|
||||
embeddings = embeddings + position_embeddings
|
||||
else:
|
||||
partition_dim = self.position_embeddings.embedding_dim
|
||||
rank = torch.distributed.get_rank(self.process_group)
|
||||
embeddings[
|
||||
..., rank * partition_dim : (rank + 1) * partition_dim
|
||||
] += position_embeddings
|
||||
if combine_batch_seqlen_dim:
|
||||
embeddings = rearrange(embeddings, "b s d -> (b s) d")
|
||||
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
||||
return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
|
||||
0
Mamba/mamba-main/mamba_ssm/models/__init__.py
Normal file
0
Mamba/mamba-main/mamba_ssm/models/__init__.py
Normal file
18
Mamba/mamba-main/mamba_ssm/models/config_mamba.py
Normal file
18
Mamba/mamba-main/mamba_ssm/models/config_mamba.py
Normal file
@ -0,0 +1,18 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaConfig:
|
||||
|
||||
d_model: int = 2560
|
||||
d_intermediate: int = 0
|
||||
n_layer: int = 64
|
||||
vocab_size: int = 50277
|
||||
ssm_cfg: dict = field(default_factory=dict)
|
||||
attn_layer_idx: list = field(default_factory=list)
|
||||
attn_cfg: dict = field(default_factory=dict)
|
||||
rms_norm: bool = True
|
||||
residual_in_fp32: bool = True
|
||||
fused_add_norm: bool = True
|
||||
pad_vocab_size_multiple: int = 8
|
||||
tie_embeddings: bool = True
|
||||
315
Mamba/mamba-main/mamba_ssm/models/mixer_seq_simple.py
Normal file
315
Mamba/mamba-main/mamba_ssm/models/mixer_seq_simple.py
Normal file
@ -0,0 +1,315 @@
|
||||
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
import json
|
||||
import os
|
||||
import copy
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mamba_ssm.models.config_mamba import MambaConfig
|
||||
from mamba_ssm.modules.mamba_simple import Mamba
|
||||
from mamba_ssm.modules.mamba2 import Mamba2
|
||||
from mamba_ssm.modules.mha import MHA
|
||||
from mamba_ssm.modules.mlp import GatedMLP
|
||||
from mamba_ssm.modules.block import Block
|
||||
from mamba_ssm.utils.generation import GenerationMixin
|
||||
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
||||
|
||||
try:
|
||||
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
||||
except ImportError:
|
||||
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
||||
|
||||
# 通过create_block 创建多个处理块
|
||||
def create_block(
|
||||
d_model,
|
||||
d_intermediate,
|
||||
ssm_cfg=None,
|
||||
attn_layer_idx=None,
|
||||
attn_cfg=None,
|
||||
norm_epsilon=1e-5,
|
||||
rms_norm=False,
|
||||
residual_in_fp32=False,
|
||||
fused_add_norm=False,
|
||||
layer_idx=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
if ssm_cfg is None:
|
||||
ssm_cfg = {}
|
||||
if attn_layer_idx is None:
|
||||
attn_layer_idx = []
|
||||
if attn_cfg is None:
|
||||
attn_cfg = {}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
if layer_idx not in attn_layer_idx:
|
||||
# Create a copy of the config to modify
|
||||
ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
|
||||
ssm_layer = ssm_cfg.pop("layer", "Mamba1")
|
||||
if ssm_layer not in ["Mamba1", "Mamba2"]:
|
||||
raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
|
||||
mixer_cls = partial(
|
||||
Mamba2 if ssm_layer == "Mamba2" else Mamba,
|
||||
layer_idx=layer_idx,
|
||||
**ssm_cfg,
|
||||
**factory_kwargs
|
||||
)
|
||||
else:
|
||||
mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
|
||||
norm_cls = partial(
|
||||
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
||||
)
|
||||
if d_intermediate == 0:
|
||||
mlp_cls = nn.Identity
|
||||
else:
|
||||
mlp_cls = partial(
|
||||
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
|
||||
)
|
||||
block = Block(
|
||||
d_model,
|
||||
mixer_cls,
|
||||
mlp_cls,
|
||||
norm_cls=norm_cls,
|
||||
fused_add_norm=fused_add_norm,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
)
|
||||
block.layer_idx = layer_idx
|
||||
return block
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
||||
def _init_weights(
|
||||
module,
|
||||
n_layer,
|
||||
initializer_range=0.02, # Now only used for embedding layer.
|
||||
rescale_prenorm_residual=True,
|
||||
n_residuals_per_layer=1, # Change to 2 if we have MLP
|
||||
):
|
||||
if isinstance(module, nn.Linear):
|
||||
if module.bias is not None:
|
||||
if not getattr(module.bias, "_no_reinit", False):
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
|
||||
if rescale_prenorm_residual:
|
||||
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
||||
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
||||
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
||||
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
||||
#
|
||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||
for name, p in module.named_parameters():
|
||||
if name in ["out_proj.weight", "fc2.weight"]:
|
||||
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
||||
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
||||
# We need to reinit p since this code could be called multiple times
|
||||
# Having just p *= scale would repeatedly scale it down
|
||||
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
||||
with torch.no_grad():
|
||||
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
||||
|
||||
|
||||
class MixerModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_layer: int,
|
||||
d_intermediate: int,
|
||||
vocab_size: int,
|
||||
ssm_cfg=None,
|
||||
attn_layer_idx=None,
|
||||
attn_cfg=None,
|
||||
norm_epsilon: float = 1e-5,
|
||||
rms_norm: bool = False,
|
||||
initializer_cfg=None,
|
||||
fused_add_norm=False,
|
||||
residual_in_fp32=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
# 获取输入token的Embedding,在 MixModule中;
|
||||
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
||||
|
||||
# We change the order of residual and layer norm:
|
||||
# Instead of LN -> Attn / MLP -> Add, we do:
|
||||
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
||||
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
||||
# This is for performance reason: we can fuse add + layer_norm.
|
||||
self.fused_add_norm = fused_add_norm
|
||||
if self.fused_add_norm:
|
||||
if layer_norm_fn is None or rms_norm_fn is None:
|
||||
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
||||
# ModuleList管理模型中的Mamba block块(Mamba block堆叠),调用create_block函数
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
create_block(
|
||||
d_model,
|
||||
d_intermediate=d_intermediate,
|
||||
ssm_cfg=ssm_cfg,
|
||||
attn_layer_idx=attn_layer_idx,
|
||||
attn_cfg=attn_cfg,
|
||||
norm_epsilon=norm_epsilon,
|
||||
rms_norm=rms_norm,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
fused_add_norm=fused_add_norm,
|
||||
layer_idx=i,
|
||||
**factory_kwargs,
|
||||
)
|
||||
#循环决定输出block个数,n_layer是超参数,在config中配置
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
||||
d_model, eps=norm_epsilon, **factory_kwargs
|
||||
)
|
||||
|
||||
self.apply(
|
||||
partial(
|
||||
_init_weights,
|
||||
n_layer=n_layer,
|
||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||
n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
|
||||
)
|
||||
)
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return {
|
||||
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
||||
for i, layer in enumerate(self.layers)
|
||||
}
|
||||
#前向传播:输入序列通过嵌入层,依次通过每个块处理,最后应用规范化层
|
||||
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
|
||||
#hidden_states 并非隐藏态,而是在Mamba块对输入的Embedding处理中的一个中间状态
|
||||
#每个Mamba块有hidden_states的输入和输出
|
||||
hidden_states = self.embedding(input_ids)
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(
|
||||
hidden_states, residual, inference_params=inference_params
|
||||
)
|
||||
if not self.fused_add_norm:
|
||||
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
||||
else:
|
||||
# Set prenorm=False here since we don't need the residual
|
||||
hidden_states = layer_norm_fn(
|
||||
hidden_states,
|
||||
self.norm_f.weight,
|
||||
self.norm_f.bias,
|
||||
eps=self.norm_f.eps,
|
||||
residual=residual,
|
||||
prenorm=False,
|
||||
residual_in_fp32=self.residual_in_fp32,
|
||||
is_rms_norm=isinstance(self.norm_f, RMSNorm)
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MambaConfig,
|
||||
initializer_cfg=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
d_model = config.d_model
|
||||
n_layer = config.n_layer
|
||||
d_intermediate = config.d_intermediate
|
||||
vocab_size = config.vocab_size
|
||||
ssm_cfg = config.ssm_cfg
|
||||
attn_layer_idx = config.attn_layer_idx
|
||||
attn_cfg = config.attn_cfg
|
||||
rms_norm = config.rms_norm
|
||||
residual_in_fp32 = config.residual_in_fp32
|
||||
fused_add_norm = config.fused_add_norm
|
||||
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
super().__init__()
|
||||
if vocab_size % pad_vocab_size_multiple != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
||||
self.backbone = MixerModel(
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
d_intermediate=d_intermediate,
|
||||
vocab_size=vocab_size,
|
||||
ssm_cfg=ssm_cfg,
|
||||
attn_layer_idx=attn_layer_idx,
|
||||
attn_cfg=attn_cfg,
|
||||
rms_norm=rms_norm,
|
||||
initializer_cfg=initializer_cfg,
|
||||
fused_add_norm=fused_add_norm,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.apply(
|
||||
partial(
|
||||
_init_weights,
|
||||
n_layer=n_layer,
|
||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||
)
|
||||
)
|
||||
self.tie_weights()
|
||||
|
||||
#tie_weights 绑定预训练权重的函数:lm_head 和 embedding的权重,这里将语言模型头(lm_head)的权重
|
||||
#设置为主干网络self.backbone中词嵌入层的权重;使用于小型数据集
|
||||
def tie_weights(self):
|
||||
if self.config.tie_embeddings:
|
||||
self.lm_head.weight = self.backbone.embedding.weight
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
|
||||
"""
|
||||
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
||||
num_last_tokens: if > 0, only return the logits for the last n tokens
|
||||
"""
|
||||
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
|
||||
if num_last_tokens > 0:
|
||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
# 加载预训练权重函数
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
||||
config_data = load_config_hf(pretrained_model_name)
|
||||
config = MambaConfig(**config_data)
|
||||
model = cls(config, device=device, dtype=dtype, **kwargs)
|
||||
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
|
||||
return model
|
||||
# 保存预训练权重函数
|
||||
def save_pretrained(self, save_directory):
|
||||
"""
|
||||
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
||||
Save the model and its configuration file to a directory.
|
||||
"""
|
||||
# Ensure save_directory exists
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# Save the model's state_dict
|
||||
model_path = os.path.join(save_directory, 'pytorch_model.bin')
|
||||
torch.save(self.state_dict(), model_path)
|
||||
|
||||
# Save the configuration of the model
|
||||
config_path = os.path.join(save_directory, 'config.json')
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(self.config.__dict__, f, indent=4)
|
||||
0
Mamba/mamba-main/mamba_ssm/modules/__init__.py
Normal file
0
Mamba/mamba-main/mamba_ssm/modules/__init__.py
Normal file
93
Mamba/mamba-main/mamba_ssm/modules/block.py
Normal file
93
Mamba/mamba-main/mamba_ssm/modules/block.py
Normal file
@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
|
||||
):
|
||||
"""
|
||||
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
|
||||
|
||||
This Block has a slightly different structure compared to a regular
|
||||
prenorm Transformer block.
|
||||
The standard block is: LN -> MHA/MLP -> Add.
|
||||
[Ref: https://arxiv.org/abs/2002.04745]
|
||||
Here we have: Add -> LN -> Mixer, returning both
|
||||
the hidden_states (output of the mixer) and the residual.
|
||||
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
||||
The residual needs to be provided (except for the very first block).
|
||||
"""
|
||||
super().__init__()
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.fused_add_norm = fused_add_norm
|
||||
self.norm = norm_cls(dim)
|
||||
self.mixer = mixer_cls(dim)
|
||||
if mlp_cls is not nn.Identity:
|
||||
self.norm2 = norm_cls(dim)
|
||||
self.mlp = mlp_cls(dim)
|
||||
else:
|
||||
self.mlp = None
|
||||
if self.fused_add_norm:
|
||||
assert RMSNorm is not None, "RMSNorm import fails"
|
||||
assert isinstance(
|
||||
self.norm, (nn.LayerNorm, RMSNorm)
|
||||
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
||||
|
||||
def forward(
|
||||
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
|
||||
):
|
||||
r"""Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
hidden_states: the sequence to the encoder layer (required).
|
||||
residual: hidden_states = Mixer(LN(residual))
|
||||
"""
|
||||
if not self.fused_add_norm:
|
||||
# residual残差连接
|
||||
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||
# 调用norm
|
||||
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
else:
|
||||
hidden_states, residual = layer_norm_fn(
|
||||
hidden_states,
|
||||
self.norm.weight,
|
||||
self.norm.bias,
|
||||
residual=residual,
|
||||
prenorm=True,
|
||||
residual_in_fp32=self.residual_in_fp32,
|
||||
eps=self.norm.eps,
|
||||
is_rms_norm=isinstance(self.norm, RMSNorm)
|
||||
)
|
||||
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
|
||||
|
||||
if self.mlp is not None:
|
||||
if not self.fused_add_norm:
|
||||
residual = hidden_states + residual
|
||||
residual = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
else:
|
||||
hidden_states, residual = layer_norm_fn(
|
||||
hidden_states,
|
||||
self.norm2.weight,
|
||||
self.norm2.bias,
|
||||
residual=residual,
|
||||
prenorm=True,
|
||||
residual_in_fp32=self.residual_in_fp32,
|
||||
eps=self.norm2.eps,
|
||||
is_rms_norm=isinstance(self.norm2, RMSNorm)
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
||||
359
Mamba/mamba-main/mamba_ssm/modules/mamba2.py
Normal file
359
Mamba/mamba-main/mamba_ssm/modules/mamba2.py
Normal file
@ -0,0 +1,359 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
try:
|
||||
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
except ImportError:
|
||||
causal_conv1d_fn, causal_conv1d_update = None, None
|
||||
|
||||
try:
|
||||
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||
except ImportError:
|
||||
selective_state_update = None
|
||||
|
||||
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
|
||||
|
||||
from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
|
||||
from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter
|
||||
|
||||
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
||||
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
||||
|
||||
|
||||
class Mamba2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
d_state=128,
|
||||
d_conv=4,
|
||||
conv_init=None,
|
||||
expand=2,
|
||||
headdim=64,
|
||||
d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
|
||||
ngroups=1,
|
||||
A_init_range=(1, 16),
|
||||
D_has_hdim=False,
|
||||
rmsnorm=True,
|
||||
norm_before_gate=False,
|
||||
dt_min=0.001,
|
||||
dt_max=0.1,
|
||||
dt_init_floor=1e-4,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
bias=False,
|
||||
conv_bias=True,
|
||||
# Fused kernel and sharding options
|
||||
chunk_size=256,
|
||||
use_mem_eff_path=True,
|
||||
layer_idx=None, # Absorb kwarg for general module
|
||||
process_group=None,
|
||||
sequence_parallel=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.d_state = d_state
|
||||
self.d_conv = d_conv
|
||||
self.conv_init = conv_init
|
||||
self.expand = expand
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
self.world_size = 1 if process_group is None else process_group.size()
|
||||
self.local_rank = 0 if process_group is None else process_group.rank()
|
||||
self.d_inner = (self.expand * self.d_model) // self.world_size
|
||||
assert self.d_inner * self.world_size == self.expand * self.d_model
|
||||
self.headdim = headdim
|
||||
self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
|
||||
assert ngroups % self.world_size == 0
|
||||
self.ngroups = ngroups // self.world_size
|
||||
assert self.d_ssm % self.headdim == 0
|
||||
self.nheads = self.d_ssm // self.headdim
|
||||
self.D_has_hdim = D_has_hdim
|
||||
self.rmsnorm = rmsnorm
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.dt_limit = dt_limit
|
||||
self.activation = "silu"
|
||||
self.chunk_size = chunk_size
|
||||
self.use_mem_eff_path = use_mem_eff_path
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Order: [z, x, B, C, dt]
|
||||
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
||||
if self.process_group is None:
|
||||
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
|
||||
else:
|
||||
self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
|
||||
process_group=self.process_group, sequence_parallel=self.sequence_parallel,
|
||||
**factory_kwargs)
|
||||
|
||||
conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=conv_dim,
|
||||
out_channels=conv_dim,
|
||||
bias=conv_bias,
|
||||
kernel_size=d_conv,
|
||||
groups=conv_dim,
|
||||
padding=d_conv - 1,
|
||||
**factory_kwargs,
|
||||
)
|
||||
if self.conv_init is not None:
|
||||
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
# Initialize log dt bias
|
||||
dt = torch.exp(
|
||||
torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
||||
+ math.log(dt_min)
|
||||
)
|
||||
dt = torch.clamp(dt, min=dt_init_floor)
|
||||
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
||||
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
||||
self.dt_bias = nn.Parameter(inv_dt)
|
||||
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
|
||||
# name.endswith("bias") in param_grouping.py
|
||||
self.dt_bias._no_weight_decay = True
|
||||
|
||||
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
||||
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
|
||||
A_log = torch.log(A).to(dtype=dtype)
|
||||
self.A_log = nn.Parameter(A_log)
|
||||
self.A_log._no_weight_decay = True
|
||||
|
||||
# D "skip" parameter
|
||||
self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
|
||||
self.D._no_weight_decay = True
|
||||
|
||||
if self.rmsnorm:
|
||||
assert RMSNormGated is not None
|
||||
self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
|
||||
group_size=self.d_ssm // ngroups, **factory_kwargs)
|
||||
|
||||
if self.process_group is None:
|
||||
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
||||
else:
|
||||
self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
|
||||
process_group=self.process_group, sequence_parallel=self.sequence_parallel,
|
||||
**factory_kwargs)
|
||||
|
||||
def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):
|
||||
"""
|
||||
u: (batch, seqlen, hidden_dim) if seqlen=None.
|
||||
If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
|
||||
split u during sequence parallel, we split the batch * seqlen dimension
|
||||
(in case batch is small).
|
||||
Returns: same shape as u
|
||||
"""
|
||||
seqlen_og = seqlen
|
||||
if seqlen is None:
|
||||
batch, seqlen, dim = u.shape
|
||||
else:
|
||||
batch_seqlen, dim = u.shape
|
||||
batch = batch_seqlen // seqlen
|
||||
|
||||
conv_state, ssm_state = None, None
|
||||
if inference_params is not None:
|
||||
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
||||
if inference_params.seqlen_offset > 0:
|
||||
# The states are updated inplace
|
||||
out, _, _ = self.step(u, conv_state, ssm_state)
|
||||
return out
|
||||
|
||||
zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
|
||||
if seqlen_og is not None:
|
||||
zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
|
||||
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
||||
A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
|
||||
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
|
||||
if self.use_mem_eff_path and inference_params is None:
|
||||
out = mamba_split_conv1d_scan_combined(
|
||||
zxbcdt,
|
||||
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||
self.conv1d.bias,
|
||||
self.dt_bias,
|
||||
A,
|
||||
D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=seq_idx,
|
||||
activation=self.activation,
|
||||
rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
|
||||
rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
|
||||
outproj_weight=self.out_proj.weight,
|
||||
outproj_bias=self.out_proj.bias,
|
||||
headdim=None if self.D_has_hdim else self.headdim,
|
||||
ngroups=self.ngroups,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
**dt_limit_kwargs,
|
||||
)
|
||||
if seqlen_og is not None:
|
||||
out = rearrange(out, "b l d -> (b l) d")
|
||||
if self.process_group is not None:
|
||||
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
||||
out = reduce_fn(out, self.process_group)
|
||||
else:
|
||||
d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
|
||||
z0, x0, z, xBC, dt = torch.split(
|
||||
zxbcdt,
|
||||
[d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
|
||||
dim=-1
|
||||
)
|
||||
if conv_state is not None:
|
||||
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
||||
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
||||
xBC_t = rearrange(xBC, "b l d -> b d l")
|
||||
conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
|
||||
assert self.activation in ["silu", "swish"]
|
||||
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
||||
xBC = self.act(
|
||||
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
|
||||
) # (B, L, self.d_ssm + 2 * ngroups * d_state)
|
||||
else:
|
||||
xBC = causal_conv1d_fn(
|
||||
xBC.transpose(1, 2),
|
||||
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
).transpose(1, 2)
|
||||
x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
|
||||
y = mamba_chunk_scan_combined(
|
||||
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
|
||||
dt,
|
||||
A,
|
||||
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
|
||||
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
|
||||
chunk_size=self.chunk_size,
|
||||
D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
|
||||
z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
|
||||
dt_bias=self.dt_bias,
|
||||
dt_softplus=True,
|
||||
seq_idx=seq_idx,
|
||||
**dt_limit_kwargs,
|
||||
return_final_states=ssm_state is not None,
|
||||
)
|
||||
if ssm_state is not None:
|
||||
y, last_state = y
|
||||
ssm_state.copy_(last_state)
|
||||
y = rearrange(y, "b l h p -> b l (h p)")
|
||||
if self.rmsnorm:
|
||||
y = self.norm(y, z)
|
||||
if d_mlp > 0:
|
||||
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
||||
if seqlen_og is not None:
|
||||
y = rearrange(y, "b l d -> (b l) d")
|
||||
out = self.out_proj(y)
|
||||
return out
|
||||
|
||||
def step(self, hidden_states, conv_state, ssm_state):
|
||||
dtype = hidden_states.dtype
|
||||
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
||||
zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
||||
d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
|
||||
z0, x0, z, xBC, dt = torch.split(
|
||||
zxbcdt,
|
||||
[d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
|
||||
dim=-1
|
||||
)
|
||||
|
||||
# Conv step
|
||||
if causal_conv1d_update is None:
|
||||
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
||||
conv_state[:, :, -1] = xBC
|
||||
xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
||||
if self.conv1d.bias is not None:
|
||||
xBC = xBC + self.conv1d.bias
|
||||
xBC = self.act(xBC).to(dtype=dtype)
|
||||
else:
|
||||
xBC = causal_conv1d_update(
|
||||
xBC,
|
||||
conv_state,
|
||||
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
)
|
||||
|
||||
x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
|
||||
A = -torch.exp(self.A_log.float()) # (nheads,)
|
||||
|
||||
# SSM step
|
||||
if selective_state_update is None:
|
||||
assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
|
||||
# Discretize A and B
|
||||
dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
|
||||
dA = torch.exp(dt * A) # (batch, nheads)
|
||||
x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
||||
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
|
||||
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
|
||||
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
|
||||
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
|
||||
y = rearrange(y, "b h p -> b (h p)")
|
||||
if not self.rmsnorm:
|
||||
y = y * self.act(z) # (B D)
|
||||
else:
|
||||
A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
|
||||
dt = repeat(dt, "b h -> b h p", p=self.headdim)
|
||||
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
|
||||
D = repeat(self.D, "h -> h p", p=self.headdim)
|
||||
B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
|
||||
C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
|
||||
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
||||
if not self.rmsnorm:
|
||||
z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
|
||||
y = selective_state_update(
|
||||
ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
|
||||
dt_bias=dt_bias, dt_softplus=True
|
||||
)
|
||||
y = rearrange(y, "b h p -> b (h p)")
|
||||
if self.rmsnorm:
|
||||
y = self.norm(y, z)
|
||||
if d_mlp > 0:
|
||||
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
||||
out = self.out_proj(y)
|
||||
return out.unsqueeze(1), conv_state, ssm_state
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
device = self.out_proj.weight.device
|
||||
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
||||
conv_state = torch.zeros(
|
||||
batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype
|
||||
)
|
||||
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
|
||||
ssm_state = torch.zeros(
|
||||
batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
|
||||
)
|
||||
return conv_state, ssm_state
|
||||
|
||||
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
||||
assert self.layer_idx is not None
|
||||
if self.layer_idx not in inference_params.key_value_memory_dict:
|
||||
batch_shape = (batch_size,)
|
||||
conv_state = torch.zeros(
|
||||
batch_size,
|
||||
self.conv1d.weight.shape[0],
|
||||
self.d_conv,
|
||||
device=self.conv1d.weight.device,
|
||||
dtype=self.conv1d.weight.dtype,
|
||||
)
|
||||
ssm_state = torch.zeros(
|
||||
batch_size,
|
||||
self.nheads,
|
||||
self.headdim,
|
||||
self.d_state,
|
||||
device=self.in_proj.weight.device,
|
||||
dtype=self.in_proj.weight.dtype,
|
||||
)
|
||||
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
||||
else:
|
||||
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
||||
# TODO: What if batch size changes between generation, and we reuse the same states?
|
||||
if initialize_states:
|
||||
conv_state.zero_()
|
||||
ssm_state.zero_()
|
||||
return conv_state, ssm_state
|
||||
199
Mamba/mamba-main/mamba_ssm/modules/mamba2_simple.py
Normal file
199
Mamba/mamba-main/mamba_ssm/modules/mamba2_simple.py
Normal file
@ -0,0 +1,199 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
try:
|
||||
from causal_conv1d import causal_conv1d_fn
|
||||
except ImportError:
|
||||
causal_conv1d_fn = None
|
||||
|
||||
try:
|
||||
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
|
||||
except ImportError:
|
||||
RMSNormGated, LayerNorm = None, None
|
||||
|
||||
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
||||
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
||||
|
||||
|
||||
class Mamba2Simple(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
d_state=64,
|
||||
d_conv=4,
|
||||
conv_init=None,
|
||||
expand=2,
|
||||
headdim=128,
|
||||
ngroups=1,
|
||||
A_init_range=(1, 16),
|
||||
dt_min=0.001,
|
||||
dt_max=0.1,
|
||||
dt_init_floor=1e-4,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
learnable_init_states=False,
|
||||
activation="swish",
|
||||
bias=False,
|
||||
conv_bias=True,
|
||||
# Fused kernel and sharding options
|
||||
chunk_size=256,
|
||||
use_mem_eff_path=True,
|
||||
layer_idx=None, # Absorb kwarg for general module
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.d_state = d_state
|
||||
self.d_conv = d_conv
|
||||
self.conv_init = conv_init
|
||||
self.expand = expand
|
||||
self.d_inner = self.expand * self.d_model
|
||||
self.headdim = headdim
|
||||
self.ngroups = ngroups
|
||||
assert self.d_inner % self.headdim == 0
|
||||
self.nheads = self.d_inner // self.headdim
|
||||
self.dt_limit = dt_limit
|
||||
self.learnable_init_states = learnable_init_states
|
||||
self.activation = activation
|
||||
self.chunk_size = chunk_size
|
||||
self.use_mem_eff_path = use_mem_eff_path
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Order: [z, x, B, C, dt]
|
||||
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
||||
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
|
||||
|
||||
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=conv_dim,
|
||||
out_channels=conv_dim,
|
||||
bias=conv_bias,
|
||||
kernel_size=d_conv,
|
||||
groups=conv_dim,
|
||||
padding=d_conv - 1,
|
||||
**factory_kwargs,
|
||||
)
|
||||
if self.conv_init is not None:
|
||||
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
|
||||
# self.conv1d.weight._no_weight_decay = True
|
||||
|
||||
if self.learnable_init_states:
|
||||
self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
|
||||
self.init_states._no_weight_decay = True
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
# Initialize log dt bias
|
||||
dt = torch.exp(
|
||||
torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
||||
+ math.log(dt_min)
|
||||
)
|
||||
dt = torch.clamp(dt, min=dt_init_floor)
|
||||
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
||||
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
||||
self.dt_bias = nn.Parameter(inv_dt)
|
||||
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
|
||||
# name.endswith("bias") in param_grouping.py
|
||||
self.dt_bias._no_weight_decay = True
|
||||
|
||||
# A parameter
|
||||
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
||||
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
|
||||
A_log = torch.log(A).to(dtype=dtype)
|
||||
self.A_log = nn.Parameter(A_log)
|
||||
# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
|
||||
self.A_log._no_weight_decay = True
|
||||
|
||||
# D "skip" parameter
|
||||
self.D = nn.Parameter(torch.ones(self.nheads, device=device))
|
||||
self.D._no_weight_decay = True
|
||||
|
||||
# Extra normalization layer right before output projection
|
||||
assert RMSNormGated is not None
|
||||
self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)
|
||||
|
||||
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
||||
|
||||
def forward(self, u, seq_idx=None):
|
||||
"""
|
||||
u: (B, L, D)
|
||||
Returns: same shape as u
|
||||
"""
|
||||
batch, seqlen, dim = u.shape
|
||||
|
||||
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
|
||||
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
|
||||
initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
|
||||
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
|
||||
|
||||
if self.use_mem_eff_path:
|
||||
# Fully fused path
|
||||
out = mamba_split_conv1d_scan_combined(
|
||||
zxbcdt,
|
||||
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||
self.conv1d.bias,
|
||||
self.dt_bias,
|
||||
A,
|
||||
D=self.D,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=seq_idx,
|
||||
activation=self.activation,
|
||||
rmsnorm_weight=self.norm.weight,
|
||||
rmsnorm_eps=self.norm.eps,
|
||||
outproj_weight=self.out_proj.weight,
|
||||
outproj_bias=self.out_proj.bias,
|
||||
headdim=self.headdim,
|
||||
ngroups=self.ngroups,
|
||||
norm_before_gate=False,
|
||||
initial_states=initial_states,
|
||||
**dt_limit_kwargs,
|
||||
)
|
||||
else:
|
||||
z, xBC, dt = torch.split(
|
||||
zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
|
||||
)
|
||||
dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
|
||||
assert self.activation in ["silu", "swish"]
|
||||
|
||||
# 1D Convolution
|
||||
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
||||
xBC = self.act(
|
||||
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
|
||||
) # (B, L, self.d_inner + 2 * ngroups * d_state)
|
||||
else:
|
||||
xBC = causal_conv1d_fn(
|
||||
x=xBC.transpose(1, 2),
|
||||
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
).transpose(1, 2)
|
||||
|
||||
# Split into 3 main branches: X, B, C
|
||||
# These correspond to V, K, Q respectively in the SSM/attention duality
|
||||
x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
|
||||
y = mamba_chunk_scan_combined(
|
||||
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
|
||||
dt,
|
||||
A,
|
||||
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
|
||||
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
|
||||
chunk_size=self.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
seq_idx=seq_idx,
|
||||
initial_states=initial_states,
|
||||
**dt_limit_kwargs,
|
||||
)
|
||||
y = rearrange(y, "b l h p -> b l (h p)")
|
||||
|
||||
# Multiply "gate" branch and apply extra normalization layer
|
||||
y = self.norm(y, z)
|
||||
out = self.out_proj(y)
|
||||
return out
|
||||
300
Mamba/mamba-main/mamba_ssm/modules/mamba_simple.py
Normal file
300
Mamba/mamba-main/mamba_ssm/modules/mamba_simple.py
Normal file
@ -0,0 +1,300 @@
|
||||
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
||||
|
||||
try:#引入加速卷积
|
||||
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
except ImportError:
|
||||
causal_conv1d_fn, causal_conv1d_update = None, None
|
||||
|
||||
try:
|
||||
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||
except ImportError:
|
||||
selective_state_update = None
|
||||
|
||||
try:
|
||||
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
||||
except ImportError:
|
||||
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
||||
|
||||
|
||||
class Mamba(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
d_state=16,
|
||||
d_conv=4,#卷积核的大小
|
||||
expand=2,#意味着d_inner 是 d_model的两倍
|
||||
dt_rank="auto",
|
||||
dt_min=0.001,
|
||||
dt_max=0.1,
|
||||
dt_init="random",
|
||||
dt_scale=1.0,
|
||||
dt_init_floor=1e-4,
|
||||
conv_bias=True,
|
||||
bias=False,
|
||||
use_fast_path=True, # Fused kernel options
|
||||
layer_idx=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.d_state = d_state
|
||||
self.d_conv = d_conv
|
||||
self.expand = expand
|
||||
self.d_inner = int(self.expand * self.d_model)
|
||||
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
||||
self.use_fast_path = use_fast_path
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
||||
#nn.Conv1d的实例化
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=self.d_inner,
|
||||
out_channels=self.d_inner,
|
||||
bias=conv_bias,
|
||||
kernel_size=d_conv,
|
||||
groups=self.d_inner,
|
||||
padding=d_conv - 1,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
self.activation = "silu"
|
||||
self.act = nn.SiLU()
|
||||
|
||||
self.x_proj = nn.Linear(
|
||||
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
||||
)
|
||||
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
||||
|
||||
# Initialize special dt projection to preserve variance at initialization
|
||||
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
||||
if dt_init == "constant":
|
||||
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
||||
elif dt_init == "random":
|
||||
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
||||
dt = torch.exp(
|
||||
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
||||
+ math.log(dt_min)
|
||||
).clamp(min=dt_init_floor)
|
||||
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
||||
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
||||
with torch.no_grad():
|
||||
self.dt_proj.bias.copy_(inv_dt)
|
||||
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
||||
self.dt_proj.bias._no_reinit = True
|
||||
|
||||
# S4D real initialization
|
||||
A = repeat(
|
||||
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
||||
"n -> d n",
|
||||
d=self.d_inner,
|
||||
).contiguous()
|
||||
A_log = torch.log(A) # Keep A_log in fp32
|
||||
self.A_log = nn.Parameter(A_log)
|
||||
self.A_log._no_weight_decay = True
|
||||
|
||||
# D "skip" parameter
|
||||
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
||||
self.D._no_weight_decay = True
|
||||
|
||||
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
||||
# 前向传播,包括各个计算模块的处理;
|
||||
def forward(self, hidden_states, inference_params=None):
|
||||
"""
|
||||
hidden_states: (B, L, D)
|
||||
Returns: same shape as hidden_states
|
||||
"""
|
||||
batch, seqlen, dim = hidden_states.shape
|
||||
|
||||
conv_state, ssm_state = None, None
|
||||
if inference_params is not None:#只在推理的时候应用step
|
||||
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
||||
if inference_params.seqlen_offset > 0:
|
||||
# The states are updated inplace
|
||||
# 将embedding隐藏状态传入step函数
|
||||
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
||||
return out
|
||||
|
||||
# We do matmul and transpose BLH -> HBL at the same time
|
||||
xz = rearrange(
|
||||
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
||||
"d (b l) -> b d l",
|
||||
l=seqlen,
|
||||
)
|
||||
if self.in_proj.bias is not None:
|
||||
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
||||
|
||||
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
||||
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
||||
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
|
||||
#前向转播:快速路径(常规路径),提高计算效率
|
||||
out = mamba_inner_fn(#该函数做前向反向传播
|
||||
xz,
|
||||
self.conv1d.weight,
|
||||
self.conv1d.bias,
|
||||
self.x_proj.weight,
|
||||
self.dt_proj.weight,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
A,
|
||||
None, # input-dependent B
|
||||
None, # input-dependent C
|
||||
self.D.float(),
|
||||
delta_bias=self.dt_proj.bias.float(),
|
||||
delta_softplus=True,
|
||||
)
|
||||
else:#常规路径
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
# Compute short convolution
|
||||
if conv_state is not None:
|
||||
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
||||
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
||||
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
|
||||
#检查是否有因果卷积
|
||||
if causal_conv1d_fn is None:
|
||||
x = self.act(self.conv1d(x)[..., :seqlen])
|
||||
else:
|
||||
assert self.activation in ["silu", "swish"]
|
||||
x = causal_conv1d_fn(
|
||||
x=x,
|
||||
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
# We're careful here about the layout, to avoid extra transposes.
|
||||
# We want dt to have d as the slowest moving dimension
|
||||
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
||||
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||
dt = self.dt_proj.weight @ dt.t()
|
||||
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
||||
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||
assert self.activation in ["silu", "swish"]
|
||||
y = selective_scan_fn(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
self.D.float(),
|
||||
z=z,
|
||||
delta_bias=self.dt_proj.bias.float(),
|
||||
delta_softplus=True,
|
||||
return_last_state=ssm_state is not None,
|
||||
)
|
||||
if ssm_state is not None:
|
||||
y, last_state = y
|
||||
ssm_state.copy_(last_state)
|
||||
y = rearrange(y, "b d l -> b l d")
|
||||
out = self.out_proj(y)
|
||||
return out
|
||||
# step 方法用于**状态空间**解码过程中的单步更新,允许一个接一个地生成序列的下一个元素。
|
||||
def step(self, hidden_states, conv_state, ssm_state):
|
||||
dtype = hidden_states.dtype
|
||||
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
||||
# hidden_states经过in_proj的处理
|
||||
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
||||
# 拆分xz,x、z都是d_inner的维度
|
||||
x, z = xz.chunk(2, dim=-1) # (B D)
|
||||
|
||||
# Conv step 卷积步骤,判断是否导入causal_conv1d进行卷积加速
|
||||
if causal_conv1d_update is None:
|
||||
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
||||
conv_state[:, :, -1] = x
|
||||
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
||||
if self.conv1d.bias is not None:
|
||||
x = x + self.conv1d.bias
|
||||
x = self.act(x).to(dtype=dtype)
|
||||
else:
|
||||
x = causal_conv1d_update(
|
||||
x,
|
||||
conv_state,
|
||||
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
)
|
||||
|
||||
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
||||
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||
# Don't add dt_bias here
|
||||
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
||||
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
||||
|
||||
# SSM step
|
||||
if selective_state_update is None:
|
||||
# Discretize A and B
|
||||
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
||||
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
||||
dB = torch.einsum("bd,bn->bdn", dt, B)
|
||||
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
||||
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
||||
y = y + self.D.to(dtype) * x
|
||||
y = y * self.act(z) # (B D)
|
||||
else:
|
||||
#提高计算速度:
|
||||
y = selective_state_update(
|
||||
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
||||
)
|
||||
|
||||
out = self.out_proj(y)
|
||||
return out.unsqueeze(1), conv_state, ssm_state
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
device = self.out_proj.weight.device
|
||||
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
||||
conv_state = torch.zeros(
|
||||
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
|
||||
)
|
||||
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
||||
# ssm_dtype = torch.float32
|
||||
ssm_state = torch.zeros(
|
||||
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
|
||||
)
|
||||
return conv_state, ssm_state
|
||||
|
||||
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
||||
assert self.layer_idx is not None
|
||||
if self.layer_idx not in inference_params.key_value_memory_dict:
|
||||
batch_shape = (batch_size,)
|
||||
conv_state = torch.zeros(
|
||||
batch_size,
|
||||
self.d_model * self.expand,
|
||||
self.d_conv,
|
||||
device=self.conv1d.weight.device,
|
||||
dtype=self.conv1d.weight.dtype,
|
||||
)
|
||||
ssm_state = torch.zeros(
|
||||
batch_size,
|
||||
self.d_model * self.expand,
|
||||
self.d_state,
|
||||
device=self.dt_proj.weight.device,
|
||||
dtype=self.dt_proj.weight.dtype,
|
||||
# dtype=torch.float32,
|
||||
)
|
||||
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
||||
else:
|
||||
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
||||
# TODO: What if batch size changes between generation, and we reuse the same states?
|
||||
if initialize_states:
|
||||
conv_state.zero_()
|
||||
ssm_state.zero_()
|
||||
return conv_state, ssm_state
|
||||
289
Mamba/mamba-main/mamba_ssm/modules/mha.py
Normal file
289
Mamba/mamba-main/mamba_ssm/modules/mha.py
Normal file
@ -0,0 +1,289 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_with_kvcache
|
||||
except ImportError:
|
||||
flash_attn_with_kvcache = None
|
||||
|
||||
try:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
except ImportError:
|
||||
RotaryEmbedding = None
|
||||
|
||||
try:
|
||||
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
except ImportError:
|
||||
causal_conv1d_fn, causal_conv1d_update = None, None
|
||||
|
||||
|
||||
def _update_kv_cache(kv, inference_params, layer_idx):
|
||||
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
||||
# Pre-allocate memory for key-values for inference.
|
||||
num_heads, head_dim = kv.shape[-2:]
|
||||
assert layer_idx in inference_params.key_value_memory_dict
|
||||
kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
|
||||
# Adjust key and value for inference
|
||||
batch_start = inference_params.batch_size_offset
|
||||
batch_end = batch_start + kv.shape[0]
|
||||
sequence_start = inference_params.seqlen_offset
|
||||
sequence_end = sequence_start + kv.shape[1]
|
||||
assert batch_end <= kv_cache.shape[0]
|
||||
assert sequence_end <= kv_cache.shape[1]
|
||||
assert kv_cache is not None
|
||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
return kv_cache[batch_start:batch_end, :sequence_end, ...]
|
||||
|
||||
|
||||
class MHA(nn.Module):
|
||||
"""Multi-head self-attention and cross-attention"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
num_heads_kv=None,
|
||||
head_dim=None, # If None, use embed_dim // num_heads
|
||||
mlp_dim=0,
|
||||
qkv_proj_bias=True,
|
||||
out_proj_bias=True,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
layer_idx=None,
|
||||
d_conv=0,
|
||||
rotary_emb_dim=0,
|
||||
rotary_emb_base=10000.0,
|
||||
rotary_emb_interleaved=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
"""
|
||||
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
|
||||
return_residual: whether to return the input x along with the output. This is for
|
||||
performance reason: for post-norm architecture, returning the input allows us
|
||||
to fuse the backward of nn.Linear with the residual connection.
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.layer_idx = layer_idx
|
||||
self.d_conv = d_conv
|
||||
self.rotary_emb_dim = rotary_emb_dim
|
||||
self.softmax_scale = softmax_scale
|
||||
self.causal = causal
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
||||
assert (
|
||||
self.num_heads % self.num_heads_kv == 0
|
||||
), "num_heads must be divisible by num_heads_kv"
|
||||
if head_dim is None:
|
||||
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
||||
self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
|
||||
self.mlp_dim = math.ceil(mlp_dim / 256) * 256
|
||||
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
||||
out_dim = self.head_dim * self.num_heads
|
||||
|
||||
if self.rotary_emb_dim > 0:
|
||||
assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.rotary_emb_dim,
|
||||
base=rotary_emb_base,
|
||||
interleaved=rotary_emb_interleaved,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
|
||||
if self.d_conv > 0:
|
||||
self.conv1d = nn.Conv1d(
|
||||
qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
|
||||
**factory_kwargs
|
||||
)
|
||||
self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
|
||||
dtype = self.out_proj.weight.dtype if dtype is None else dtype
|
||||
device = self.out_proj.weight.device
|
||||
if self.d_conv > 0:
|
||||
conv_state = torch.zeros(
|
||||
batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
conv_state = None
|
||||
kv_cache = torch.empty(
|
||||
batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
|
||||
)
|
||||
return kv_cache, conv_state
|
||||
|
||||
def _update_kv_cache(self, kv, inference_params):
|
||||
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
||||
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
||||
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
||||
|
||||
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
|
||||
"""
|
||||
Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
|
||||
q: (batch_size, seqlen_q, nheads, head_dim)
|
||||
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
|
||||
"""
|
||||
assert inference_params is not None and inference_params.seqlen_offset > 0
|
||||
if self.rotary_emb_dim > 0:
|
||||
self.rotary_emb._update_cos_sin_cache(
|
||||
inference_params.max_seqlen, device=q.device, dtype=q.dtype
|
||||
)
|
||||
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
|
||||
else:
|
||||
rotary_cos, rotary_sin = None, None
|
||||
batch = q.shape[0]
|
||||
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
|
||||
kv_cache = kv_cache[:batch]
|
||||
cache_seqlens = (
|
||||
inference_params.lengths_per_sample[:batch]
|
||||
if inference_params.lengths_per_sample is not None
|
||||
else inference_params.seqlen_offset
|
||||
)
|
||||
assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
|
||||
context = flash_attn_with_kvcache(
|
||||
q,
|
||||
kv_cache[:, :, 0],
|
||||
kv_cache[:, :, 1],
|
||||
kv[:, :, 0],
|
||||
kv[:, :, 1],
|
||||
rotary_cos=rotary_cos,
|
||||
rotary_sin=rotary_sin,
|
||||
cache_seqlens=cache_seqlens,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=self.causal,
|
||||
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
|
||||
)
|
||||
return context
|
||||
|
||||
def _update_kvcache_attention(self, q, kv, inference_params):
|
||||
"""Write kv to inference_params, then do attention"""
|
||||
if (
|
||||
inference_params.seqlen_offset == 0
|
||||
or flash_attn_with_kvcache is None
|
||||
):
|
||||
# TODO: this only uses seqlen_offset and not lengths_per_sample.
|
||||
kv = self._update_kv_cache(kv, inference_params)
|
||||
k, v = kv.unbind(dim=-3)
|
||||
return F.scaled_dot_product_attention(
|
||||
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
|
||||
).transpose(1, 2)
|
||||
else:
|
||||
batch = q.shape[0]
|
||||
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
|
||||
cache_seqlens = (
|
||||
inference_params.lengths_per_sample[:batch]
|
||||
if inference_params.lengths_per_sample is not None
|
||||
else inference_params.seqlen_offset
|
||||
)
|
||||
return flash_attn_with_kvcache(
|
||||
q,
|
||||
kv_cache[:, :, 0],
|
||||
kv_cache[:, :, 1],
|
||||
kv[:, :, 0],
|
||||
kv[:, :, 1],
|
||||
cache_seqlens=cache_seqlens,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=self.causal,
|
||||
)
|
||||
|
||||
def forward(self, x, inference_params=None):
|
||||
"""
|
||||
Arguments:
|
||||
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
|
||||
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
|
||||
is the is the sum of the sequence lengths in the batch.
|
||||
inference_params: for generation. Adapted from Megatron-LM (and Apex)
|
||||
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
|
||||
"""
|
||||
if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
|
||||
inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
|
||||
x.shape[0], inference_params.max_seqlen, dtype=x.dtype
|
||||
)
|
||||
seqlen_offset = (
|
||||
0
|
||||
if inference_params is None
|
||||
else (
|
||||
inference_params.lengths_per_sample
|
||||
if inference_params.lengths_per_sample is not None
|
||||
else inference_params.seqlen_offset
|
||||
)
|
||||
)
|
||||
rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
|
||||
qkv = self.in_proj(x)
|
||||
if self.mlp_dim > 0:
|
||||
qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
|
||||
x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
|
||||
x_mlp = x_mlp_up * F.silu(x_mlp_gate)
|
||||
if self.d_conv > 0:
|
||||
# The inference code for conv1d is pretty messy, should clean it up
|
||||
if (inference_params is None or inference_params.seqlen_offset == 0):
|
||||
if causal_conv1d_fn is None:
|
||||
qkv = rearrange(
|
||||
self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
|
||||
).contiguous()
|
||||
else:
|
||||
qkv = causal_conv1d_fn(
|
||||
qkv.transpose(1, 2),
|
||||
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||
self.conv1d.bias
|
||||
).transpose(1, 2)
|
||||
if inference_params is not None:
|
||||
_, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
|
||||
# If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
||||
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
||||
qkv_t = rearrange(qkv, "b l d -> b d l")
|
||||
conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
|
||||
else:
|
||||
_, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
|
||||
assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
||||
qkv = qkv.squeeze(1)
|
||||
# Conv step
|
||||
if causal_conv1d_update is None:
|
||||
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
||||
conv_state[:, :, -1] = qkv
|
||||
qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
||||
if self.conv1d.bias is not None:
|
||||
qkv = qkv + self.conv1d.bias
|
||||
else:
|
||||
qkv = causal_conv1d_update(
|
||||
qkv,
|
||||
conv_state,
|
||||
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
||||
self.conv1d.bias
|
||||
)
|
||||
qkv = qkv.unsqueeze(1)
|
||||
q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
|
||||
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
||||
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
||||
if (
|
||||
inference_params is None
|
||||
or inference_params.seqlen_offset == 0
|
||||
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
|
||||
):
|
||||
if self.rotary_emb_dim > 0:
|
||||
q, kv = self.rotary_emb(
|
||||
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
|
||||
)
|
||||
if inference_params is None:
|
||||
k, v = kv.unbind(dim=-3)
|
||||
context = F.scaled_dot_product_attention(
|
||||
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
|
||||
).transpose(1, 2)
|
||||
else:
|
||||
context = self._update_kvcache_attention(q, kv, inference_params)
|
||||
else:
|
||||
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
||||
context = rearrange(context, "... h d -> ... (h d)")
|
||||
if self.mlp_dim > 0:
|
||||
context = torch.cat([context, x_mlp], dim=-1)
|
||||
out = self.out_proj(context)
|
||||
return out
|
||||
34
Mamba/mamba-main/mamba_ssm/modules/mlp.py
Normal file
34
Mamba/mamba-main/mamba_ssm/modules/mlp.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class GatedMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
activation=F.silu,
|
||||
bias=False,
|
||||
multiple_of=128,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
out_features = out_features if out_features is not None else in_features
|
||||
hidden_features = (
|
||||
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
||||
)
|
||||
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
||||
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
|
||||
self.activation = activation
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.fc1(x)
|
||||
y, gate = y.chunk(2, dim=-1)
|
||||
y = y * self.activation(gate)
|
||||
y = self.fc2(y)
|
||||
return y
|
||||
103
Mamba/mamba-main/mamba_ssm/modules/ssd_minimal.py
Normal file
103
Mamba/mamba-main/mamba_ssm/modules/ssd_minimal.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright (c) 2024, Albert Gu and Tri Dao.
|
||||
"""Minimal implementation of SSD.
|
||||
|
||||
This is the same as Listing 1 from the paper.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
||||
|
||||
|
||||
def segsum_unstable(x):
|
||||
"""Naive segment sum calculation."""
|
||||
T = x.size(-1)
|
||||
x_cumsum = torch.cumsum(x, dim=-1)
|
||||
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
||||
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
||||
return x_segsum
|
||||
|
||||
def segsum(x):
|
||||
"""More stable segment sum calculation."""
|
||||
T = x.size(-1)
|
||||
x = repeat(x, "... d -> ... d e", e=T)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
|
||||
x = x.masked_fill(~mask, 0)
|
||||
x_segsum = torch.cumsum(x, dim=-2)
|
||||
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
||||
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
||||
return x_segsum
|
||||
|
||||
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
|
||||
"""
|
||||
Arguments:
|
||||
X: (batch, length, n_heads, d_head)
|
||||
A: (batch, length, n_heads)
|
||||
B: (batch, length, n_heads, d_state)
|
||||
C: (batch, length, n_heads, d_state)
|
||||
Return:
|
||||
Y: (batch, length, n_heads, d_head)
|
||||
"""
|
||||
assert X.dtype == A.dtype == B.dtype == C.dtype
|
||||
assert X.shape[1] % block_len == 0
|
||||
|
||||
# Rearrange into blocks/chunks
|
||||
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
|
||||
|
||||
A = rearrange(A, "b c l h -> b h c l")
|
||||
A_cumsum = torch.cumsum(A, dim=-1)
|
||||
|
||||
# 1. Compute the output for each intra-chunk (diagonal blocks)
|
||||
L = torch.exp(segsum(A))
|
||||
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
|
||||
|
||||
# 2. Compute the state for each intra-chunk
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
|
||||
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
if initial_states is None:
|
||||
initial_states = torch.zeros_like(states[:, :1])
|
||||
states = torch.cat([initial_states, states], dim=1)
|
||||
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
|
||||
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
|
||||
states, final_state = new_states[:, :-1], new_states[:, -1]
|
||||
|
||||
# 4. Compute state -> output conversion per chunk
|
||||
# (left term of low-rank factorization of off-diagonal blocks; C terms)
|
||||
state_decay_out = torch.exp(A_cumsum)
|
||||
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
|
||||
|
||||
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
|
||||
Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
|
||||
return Y, final_state
|
||||
|
||||
|
||||
# Simple test
|
||||
def test_correctness():
|
||||
torch.manual_seed(42)
|
||||
|
||||
## Dimensions
|
||||
# Denoted (B, T, Q, D, P) in the paper
|
||||
batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
|
||||
nheads = dim // headdim # (H) in the paper
|
||||
ngroups = 1 # (G) in the paper
|
||||
dstate = 64 # (N) in the paper
|
||||
dtype = torch.float32
|
||||
device = "cuda"
|
||||
|
||||
x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
|
||||
dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_()
|
||||
A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_()
|
||||
B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
|
||||
C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
|
||||
D = torch.randn(nheads, dtype=dtype, device=device)
|
||||
|
||||
# Comparing fused version and minimal version
|
||||
y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
|
||||
y_min, _ = ssd_minimal_discrete(x*dt.unsqueeze(-1), A*dt, B, C, chunk_size)
|
||||
0
Mamba/mamba-main/mamba_ssm/ops/__init__.py
Normal file
0
Mamba/mamba-main/mamba_ssm/ops/__init__.py
Normal file
357
Mamba/mamba-main/mamba_ssm/ops/selective_scan_interface.py
Normal file
357
Mamba/mamba-main/mamba_ssm/ops/selective_scan_interface.py
Normal file
@ -0,0 +1,357 @@
|
||||
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
try:
|
||||
from causal_conv1d import causal_conv1d_fn
|
||||
import causal_conv1d_cuda
|
||||
except ImportError:
|
||||
causal_conv1d_fn = None
|
||||
causal_conv1d_cuda = None
|
||||
|
||||
import selective_scan_cuda
|
||||
|
||||
|
||||
class SelectiveScanFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
||||
return_last_state=False):
|
||||
if u.stride(-1) != 1:
|
||||
u = u.contiguous()
|
||||
if delta.stride(-1) != 1:
|
||||
delta = delta.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
if B.dim() == 3:
|
||||
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
||||
ctx.squeeze_B = True
|
||||
if C.dim() == 3:
|
||||
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
||||
ctx.squeeze_C = True
|
||||
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
|
||||
ctx.delta_softplus = delta_softplus
|
||||
ctx.has_z = z is not None
|
||||
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
||||
if not ctx.has_z:
|
||||
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
||||
return out if not return_last_state else (out, last_state)
|
||||
else:
|
||||
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
||||
out_z = rest[0]
|
||||
return out_z if not return_last_state else (out_z, last_state)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
if not ctx.has_z:
|
||||
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
||||
z = None
|
||||
out = None
|
||||
else:
|
||||
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
||||
# backward of selective_scan_cuda with the backward of chunk).
|
||||
# Here we just pass in None and dz will be allocated in the C++ code.
|
||||
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
||||
u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
|
||||
False # option to recompute out_z, not used here
|
||||
)
|
||||
dz = rest[0] if ctx.has_z else None
|
||||
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
||||
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
||||
return (du, ddelta, dA, dB, dC,
|
||||
dD if D is not None else None,
|
||||
dz,
|
||||
ddelta_bias if delta_bias is not None else None,
|
||||
None,
|
||||
None)
|
||||
|
||||
|
||||
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
||||
return_last_state=False):
|
||||
"""if return_last_state is True, returns (out, last_state)
|
||||
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
|
||||
not considered in the backward pass.
|
||||
"""
|
||||
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
|
||||
|
||||
|
||||
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
||||
return_last_state=False):
|
||||
"""
|
||||
u: r(B D L)
|
||||
delta: r(B D L)
|
||||
A: c(D N) or r(D N)
|
||||
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
D: r(D)
|
||||
z: r(B D L)
|
||||
delta_bias: r(D), fp32
|
||||
|
||||
out: r(B D L)
|
||||
last_state (optional): r(B D dstate) or c(B D dstate)
|
||||
"""
|
||||
dtype_in = u.dtype
|
||||
u = u.float()
|
||||
delta = delta.float()
|
||||
if delta_bias is not None:
|
||||
delta = delta + delta_bias[..., None].float()
|
||||
if delta_softplus:
|
||||
delta = F.softplus(delta)
|
||||
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
||||
is_variable_B = B.dim() >= 3
|
||||
is_variable_C = C.dim() >= 3
|
||||
if A.is_complex():
|
||||
if is_variable_B:
|
||||
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
|
||||
if is_variable_C:
|
||||
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
|
||||
else:
|
||||
B = B.float()
|
||||
C = C.float()
|
||||
x = A.new_zeros((batch, dim, dstate))
|
||||
ys = []
|
||||
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
||||
if not is_variable_B:
|
||||
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
|
||||
else:
|
||||
if B.dim() == 3:
|
||||
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
|
||||
else:
|
||||
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
||||
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
|
||||
if is_variable_C and C.dim() == 4:
|
||||
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
||||
last_state = None
|
||||
for i in range(u.shape[2]):
|
||||
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
||||
if not is_variable_C:
|
||||
y = torch.einsum('bdn,dn->bd', x, C)
|
||||
else:
|
||||
if C.dim() == 3:
|
||||
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
|
||||
else:
|
||||
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
|
||||
if i == u.shape[2] - 1:
|
||||
last_state = x
|
||||
if y.is_complex():
|
||||
y = y.real * 2
|
||||
ys.append(y)
|
||||
y = torch.stack(ys, dim=2) # (batch dim L)
|
||||
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
||||
if z is not None:
|
||||
out = out * F.silu(z)
|
||||
out = out.to(dtype=dtype_in)
|
||||
return out if not return_last_state else (out, last_state)
|
||||
|
||||
|
||||
class MambaInnerFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
||||
out_proj_weight, out_proj_bias,
|
||||
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
||||
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
|
||||
"""
|
||||
xz: (batch, dim, seqlen)
|
||||
"""
|
||||
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
||||
assert checkpoint_lvl in [0, 1]
|
||||
L = xz.shape[-1]
|
||||
delta_rank = delta_proj_weight.shape[1]
|
||||
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
||||
if torch.is_autocast_enabled():
|
||||
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
if out_proj_bias is not None else None)
|
||||
if xz.stride(-1) != 1:
|
||||
xz = xz.contiguous()
|
||||
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
||||
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
||||
x, conv1d_weight, conv1d_bias, None, None, None, True
|
||||
)
|
||||
# We're being very careful here about the layout, to avoid extra transposes.
|
||||
# We want delta to have d as the slowest moving dimension
|
||||
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
||||
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
||||
ctx.is_variable_B = B is None
|
||||
ctx.is_variable_C = C is None
|
||||
ctx.B_proj_bias_is_None = B_proj_bias is None
|
||||
ctx.C_proj_bias_is_None = C_proj_bias is None
|
||||
if B is None: # variable B
|
||||
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
||||
if B_proj_bias is not None:
|
||||
B = B + B_proj_bias.to(dtype=B.dtype)
|
||||
if not A.is_complex():
|
||||
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
||||
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
else:
|
||||
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
||||
else:
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C is None: # variable C
|
||||
C = x_dbl[:, -d_state:] # (bl dstate)
|
||||
if C_proj_bias is not None:
|
||||
C = C + C_proj_bias.to(dtype=C.dtype)
|
||||
if not A.is_complex():
|
||||
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
||||
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
else:
|
||||
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
||||
else:
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
||||
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
||||
)
|
||||
ctx.delta_softplus = delta_softplus
|
||||
ctx.out_proj_bias_is_None = out_proj_bias is None
|
||||
ctx.checkpoint_lvl = checkpoint_lvl
|
||||
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
||||
conv1d_out, delta = None, None
|
||||
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
||||
delta_proj_weight, out_proj_weight, conv1d_out, delta,
|
||||
A, B, C, D, delta_bias, scan_intermediates, out)
|
||||
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, dout):
|
||||
# dout: (batch, seqlen, dim)
|
||||
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
||||
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
|
||||
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
|
||||
L = xz.shape[-1]
|
||||
delta_rank = delta_proj_weight.shape[1]
|
||||
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
if ctx.checkpoint_lvl == 1:
|
||||
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
||||
x, conv1d_weight, conv1d_bias, None, None, None, True
|
||||
)
|
||||
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
||||
"d (b l) -> b d l", l = L)
|
||||
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
||||
# backward of selective_scan_cuda with the backward of chunk).
|
||||
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
||||
dx, dz = dxz.chunk(2, dim=1)
|
||||
dout = rearrange(dout, "b l e -> e (b l)")
|
||||
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
||||
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
||||
conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
|
||||
ctx.delta_softplus,
|
||||
True # option to recompute out_z
|
||||
)
|
||||
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
||||
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
||||
dD = dD if D is not None else None
|
||||
dx_dbl = torch.empty_like(x_dbl)
|
||||
dB_proj_bias = None
|
||||
if ctx.is_variable_B:
|
||||
if not A.is_complex():
|
||||
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
||||
else:
|
||||
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
||||
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
||||
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
||||
dB = None
|
||||
dC_proj_bias = None
|
||||
if ctx.is_variable_C:
|
||||
if not A.is_complex():
|
||||
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
||||
else:
|
||||
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
||||
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
||||
dx_dbl[:, -d_state:] = dC # (bl d)
|
||||
dC = None
|
||||
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
||||
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
||||
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
||||
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
||||
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
||||
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
||||
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
||||
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
||||
# backward of conv1d with the backward of chunk).
|
||||
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
||||
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
|
||||
)
|
||||
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
||||
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
||||
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
||||
dout_proj_weight, dout_proj_bias,
|
||||
dA, dB, dC, dD,
|
||||
ddelta_bias if delta_bias is not None else None,
|
||||
dB_proj_bias, dC_proj_bias, None)
|
||||
|
||||
|
||||
def mamba_inner_fn(
|
||||
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
||||
out_proj_weight, out_proj_bias,
|
||||
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
||||
C_proj_bias=None, delta_softplus=True
|
||||
):
|
||||
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
||||
out_proj_weight, out_proj_bias,
|
||||
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
||||
|
||||
|
||||
def mamba_inner_ref(
|
||||
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
||||
out_proj_weight, out_proj_bias,
|
||||
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
||||
C_proj_bias=None, delta_softplus=True
|
||||
):
|
||||
assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
|
||||
L = xz.shape[-1]
|
||||
delta_rank = delta_proj_weight.shape[1]
|
||||
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
|
||||
# We're being very careful here about the layout, to avoid extra transposes.
|
||||
# We want delta to have d as the slowest moving dimension
|
||||
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
||||
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
|
||||
delta = rearrange(delta, "d (b l) -> b d l", l=L)
|
||||
if B is None: # variable B
|
||||
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
|
||||
if B_proj_bias is not None:
|
||||
B = B + B_proj_bias.to(dtype=B.dtype)
|
||||
if not A.is_complex():
|
||||
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
||||
else:
|
||||
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
||||
if C is None: # variable B
|
||||
C = x_dbl[:, -d_state:] # (bl d)
|
||||
if C_proj_bias is not None:
|
||||
C = C + C_proj_bias.to(dtype=C.dtype)
|
||||
if not A.is_complex():
|
||||
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
||||
else:
|
||||
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
||||
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
|
||||
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
||||
0
Mamba/mamba-main/mamba_ssm/ops/triton/__init__.py
Normal file
0
Mamba/mamba-main/mamba_ssm/ops/triton/__init__.py
Normal file
153
Mamba/mamba-main/mamba_ssm/ops/triton/k_activations.py
Normal file
153
Mamba/mamba-main/mamba_ssm/ops/triton/k_activations.py
Normal file
@ -0,0 +1,153 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_N': 32}),
|
||||
triton.Config({'BLOCK_N': 64}),
|
||||
triton.Config({'BLOCK_N': 128}),
|
||||
triton.Config({'BLOCK_N': 256}),
|
||||
triton.Config({'BLOCK_N': 512}),
|
||||
triton.Config({'BLOCK_N': 1024}),
|
||||
],
|
||||
key=['ncols'],
|
||||
)
|
||||
@triton.jit
|
||||
def _swiglu_fwd_kernel(
|
||||
X,
|
||||
Y,
|
||||
OUT,
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_out_row,
|
||||
ncols,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
start_col = tl.program_id(1) * BLOCK_N
|
||||
X += row * stride_x_row
|
||||
Y += row * stride_y_row
|
||||
OUT += row * stride_out_row
|
||||
cols = start_col + tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
||||
y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
||||
out = x * tl.sigmoid(x) * y
|
||||
tl.store(OUT + cols, out, mask=cols < ncols)
|
||||
|
||||
|
||||
def _swiglu_fwd(xy, out=None):
|
||||
if xy.stride(-1) != 1:
|
||||
xy = xy.contiguous()
|
||||
batch_shape = xy.shape[:-1]
|
||||
xy = xy.reshape(-1, xy.shape[-1])
|
||||
x, y = xy.chunk(2, dim=-1)
|
||||
if out is None:
|
||||
out = torch.empty_like(x)
|
||||
else:
|
||||
out = out.reshape(-1, out.shape[-1])
|
||||
assert out.shape == x.shape
|
||||
assert out.stride(-1) == 1
|
||||
M, N = x.shape
|
||||
grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
|
||||
with torch.cuda.device(x.device.index):
|
||||
_swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
|
||||
return out.reshape(*batch_shape, out.shape[-1])
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_N': 32}),
|
||||
triton.Config({'BLOCK_N': 64}),
|
||||
triton.Config({'BLOCK_N': 128}),
|
||||
triton.Config({'BLOCK_N': 256}),
|
||||
triton.Config({'BLOCK_N': 512}),
|
||||
triton.Config({'BLOCK_N': 1024}),
|
||||
],
|
||||
key=['ncols'],
|
||||
)
|
||||
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
|
||||
@triton.jit
|
||||
def _swiglu_bwd_kernel(
|
||||
X,
|
||||
Y,
|
||||
DOUT,
|
||||
OUT,
|
||||
DX,
|
||||
DY,
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_dout_row,
|
||||
stride_out_row,
|
||||
stride_dx_row,
|
||||
stride_dy_row,
|
||||
ncols,
|
||||
BLOCK_N: tl.constexpr,
|
||||
RECOMPUTE_OUTPUT: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
start_col = tl.program_id(1) * BLOCK_N
|
||||
X += row * stride_x_row
|
||||
Y += row * stride_y_row
|
||||
DOUT += row * stride_dout_row
|
||||
if RECOMPUTE_OUTPUT:
|
||||
OUT += row * stride_out_row
|
||||
DX += row * stride_dx_row
|
||||
DY += row * stride_dy_row
|
||||
cols = start_col + tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
||||
y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
||||
dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
||||
x_sigmoid = tl.sigmoid(x)
|
||||
dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
|
||||
dy = x * x_sigmoid * dout
|
||||
tl.store(DX + cols, dx, mask=cols < ncols)
|
||||
tl.store(DY + cols, dy, mask=cols < ncols)
|
||||
if RECOMPUTE_OUTPUT:
|
||||
out = x * x_sigmoid * y
|
||||
tl.store(OUT + cols, out, mask=cols < ncols)
|
||||
|
||||
|
||||
def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
|
||||
if xy.stride(-1) != 1:
|
||||
xy = xy.contiguous()
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
batch_shape = xy.shape[:-1]
|
||||
xy = xy.reshape(-1, xy.shape[-1])
|
||||
x, y = xy.chunk(2, dim=-1)
|
||||
dout = dout.reshape(-1, dout.shape[-1])
|
||||
assert dout.shape == x.shape
|
||||
if dxy is None:
|
||||
dxy = torch.empty_like(xy)
|
||||
else:
|
||||
dxy = dxy.reshape(-1, dxy.shape[-1])
|
||||
assert dxy.shape == xy.shape
|
||||
dx, dy = dxy.chunk(2, dim=-1)
|
||||
assert dx.stride(-1) == 1
|
||||
assert dy.stride(-1) == 1
|
||||
if recompute_output:
|
||||
if out is None:
|
||||
out = torch.empty_like(x)
|
||||
else:
|
||||
out = out.reshape(-1, out.shape[-1])
|
||||
assert out.shape == x.shape
|
||||
assert out.stride(-1) == 1
|
||||
M, N = x.shape
|
||||
grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
|
||||
with torch.cuda.device(x.device.index):
|
||||
_swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
|
||||
x.stride(0), y.stride(0), dout.stride(0),
|
||||
out.stride(0) if recompute_output else 0,
|
||||
dx.stride(0), dy.stride(0),
|
||||
N)
|
||||
if not recompute_output:
|
||||
return dxy.reshape(*batch_shape, dxy.shape[-1])
|
||||
else:
|
||||
return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
|
||||
1115
Mamba/mamba-main/mamba_ssm/ops/triton/layer_norm.py
Normal file
1115
Mamba/mamba-main/mamba_ssm/ops/triton/layer_norm.py
Normal file
File diff suppressed because it is too large
Load Diff
437
Mamba/mamba-main/mamba_ssm/ops/triton/layernorm_gated.py
Normal file
437
Mamba/mamba-main/mamba_ssm/ops/triton/layernorm_gated.py
Normal file
@ -0,0 +1,437 @@
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
|
||||
dtype = x.dtype
|
||||
N = x.shape[-1]
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
z = z.float() if z is not None else z
|
||||
if z is not None and not norm_before_gate:
|
||||
x = x * F.silu(z)
|
||||
if group_size is None:
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
||||
else:
|
||||
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
||||
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
if z is not None and norm_before_gate:
|
||||
out *= F.silu(z)
|
||||
return out.to(dtype)
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_z_row,
|
||||
M, # number of rows in X
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
|
||||
x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
|
||||
M, group_size, eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_bwd_kernel(
|
||||
X, # pointer to the input
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Y, # pointer to the output to be recomputed
|
||||
DY, # pointer to the output gradient
|
||||
DX, # pointer to the input gradient
|
||||
DW, # pointer to the partial sum of weights gradient
|
||||
DB, # pointer to the partial sum of biases gradient
|
||||
DZ, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_z_row,
|
||||
stride_y_row,
|
||||
stride_dy_row,
|
||||
stride_dx_row,
|
||||
stride_dz_row,
|
||||
stride_dw_row,
|
||||
stride_db_row,
|
||||
M, # number of rows in X
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
rows_per_program,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
RECOMPUTE_OUTPUT: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the elements of X, DX, and DY it should compute.
|
||||
row_block_id = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
row_start = row_block_id * rows_per_program
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
mask = cols < N
|
||||
X += row_start * stride_x_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row_start * stride_z_row + group * N
|
||||
DZ += row_start * stride_dz_row + group * N
|
||||
DY += row_start * stride_dy_row + group * N
|
||||
DX += row_start * stride_dx_row + group * N
|
||||
if RECOMPUTE_OUTPUT:
|
||||
Y += row_start * stride_y_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
|
||||
B += group * N
|
||||
b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
|
||||
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
if HAS_BIAS:
|
||||
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
row_end = min((row_block_id + 1) * rows_per_program, M)
|
||||
for row in range(row_start, row_end):
|
||||
# Load data to SRAM
|
||||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
||||
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.load(Mean + row)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
|
||||
x_og = x
|
||||
x = x_og * z * tl.sigmoid(z)
|
||||
rstd = tl.load(Rstd + row)
|
||||
# Compute dx
|
||||
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
xhat = tl.where(mask, xhat, 0.)
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
|
||||
z_sigmoid = tl.sigmoid(z)
|
||||
y = xhat * w + b if HAS_BIAS else xhat * w
|
||||
if RECOMPUTE_OUTPUT:
|
||||
tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
|
||||
dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
|
||||
tl.store(DZ + cols, dz, mask=mask)
|
||||
dy *= z * z_sigmoid
|
||||
else:
|
||||
if RECOMPUTE_OUTPUT:
|
||||
y = xhat * w + b if HAS_BIAS else xhat * w
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
wdy = w * dy
|
||||
c1 = tl.sum(xhat * wdy, axis=0) / N
|
||||
if not IS_RMS_NORM:
|
||||
c2 = tl.sum(wdy, axis=0) / N
|
||||
dx = (wdy - (xhat * c1 + c2)) * rstd
|
||||
else:
|
||||
dx = (wdy - xhat * c1) * rstd
|
||||
dw += dy * xhat
|
||||
if HAS_BIAS:
|
||||
db += dy
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z_sigmoid = tl.sigmoid(z)
|
||||
dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
|
||||
tl.store(DZ + cols, dz, mask=mask)
|
||||
dx *= z * z_sigmoid
|
||||
# Write dx
|
||||
tl.store(DX + cols, dx, mask=mask)
|
||||
|
||||
X += stride_x_row
|
||||
if HAS_Z:
|
||||
Z += stride_z_row
|
||||
DZ += stride_dz_row
|
||||
if RECOMPUTE_OUTPUT:
|
||||
Y += stride_y_row
|
||||
DY += stride_dy_row
|
||||
DX += stride_dx_row
|
||||
tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
|
||||
if HAS_BIAS:
|
||||
tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
|
||||
norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
assert dy.stride(-1) == 1
|
||||
assert dy.shape == (M, N)
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
dx = torch.empty_like(x)
|
||||
if dz is not None:
|
||||
assert z is not None
|
||||
assert dz.shape == z.shape
|
||||
assert dz.stride(-1) == 1
|
||||
else:
|
||||
dz = torch.empty_like(z) if z is not None else None
|
||||
if recompute_output:
|
||||
if out is None:
|
||||
out = torch.empty_like(x)
|
||||
assert out.shape == x.shape
|
||||
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
||||
# If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
|
||||
# would limit the occupancy.
|
||||
nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
|
||||
_dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
|
||||
_db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
|
||||
rows_per_program = math.ceil(M / nrow_groups)
|
||||
grid = (nrow_groups, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
|
||||
dy, dx, _dw, _db, dz, mean, rstd,
|
||||
x.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
0 if not recompute_output else out.stride(0),
|
||||
dy.stride(0), dx.stride(0),
|
||||
dz.stride(0) if dz is not None else 0,
|
||||
_dw.stride(0),
|
||||
_db.stride(0) if _db is not None else 0,
|
||||
M, group_size, eps,
|
||||
rows_per_program,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps)
|
||||
dw = _dw.sum(0).to(weight.dtype)
|
||||
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
||||
return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.eps = eps
|
||||
ctx.group_size = group_size
|
||||
ctx.norm_before_gate = norm_before_gate
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy):
|
||||
x, weight, bias, mean, rstd, z = ctx.saved_tensors
|
||||
dy = dy.reshape(-1, dy.shape[-1])
|
||||
if dy.stride(-1) != 1:
|
||||
dy = dy.contiguous()
|
||||
assert dy.shape == x.shape
|
||||
dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
|
||||
ctx.norm_before_gate, ctx.is_rms_norm)
|
||||
return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
|
||||
|
||||
|
||||
def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
|
||||
|
||||
|
||||
def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
torch.nn.init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
266
Mamba/mamba-main/mamba_ssm/ops/triton/selective_state_update.py
Normal file
266
Mamba/mamba-main/mamba_ssm/ops/triton/selective_state_update.py
Normal file
@ -0,0 +1,266 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from mamba_ssm.ops.triton.softplus import softplus
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
||||
@triton.jit
|
||||
#用于选择性状态更新
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
|
||||
# Matrix dimensions
|
||||
batch, nheads, dim, dstate, nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
|
||||
stride_x_batch, stride_x_head, stride_x_dim,
|
||||
stride_dt_batch, stride_dt_head, stride_dt_dim,
|
||||
stride_dt_bias_head, stride_dt_bias_dim,
|
||||
stride_A_head, stride_A_dim, stride_A_dstate,
|
||||
stride_B_batch, stride_B_group, stride_B_dstate,
|
||||
stride_C_batch, stride_C_group, stride_C_dstate,
|
||||
stride_D_head, stride_D_dim,
|
||||
stride_z_batch, stride_z_head, stride_z_dim,
|
||||
stride_out_batch, stride_out_head, stride_out_dim,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
TIE_HDIM: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptr += pid_h * stride_dt_bias_head
|
||||
A_ptr += pid_h * stride_A_head
|
||||
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
|
||||
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
|
||||
if HAS_Z:
|
||||
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
||||
if HAS_D:
|
||||
D_ptr += pid_h * stride_D_head
|
||||
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_D:
|
||||
D_ptrs = D_ptr + offs_m * stride_D_dim
|
||||
if HAS_Z:
|
||||
z_ptrs = z_ptr + offs_m * stride_z_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
|
||||
state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_Z:
|
||||
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
|
||||
if not TIE_HDIM:
|
||||
dB = B[None, :] * dt[:, None]
|
||||
else:
|
||||
dB = B * dt # vector of size (dstate,)
|
||||
state = state * dA + dB * x[:, None]
|
||||
tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
if HAS_Z:
|
||||
out *= z * tl.sigmoid(z)
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
|
||||
|
||||
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
Return:
|
||||
out: (batch, dim) or (batch, nheads, dim)
|
||||
"""
|
||||
has_heads = state.dim() > 3
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
batch, nheads, dim, dstate = state.shape
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
out = torch.empty_like(x)
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
|
||||
# We don't want autotune since it will overwrite the state
|
||||
# We instead tune by hand.
|
||||
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
|
||||
else ((16, 4) if dstate <= 32 else
|
||||
((8, 4) if dstate <= 64 else
|
||||
((4, 4) if dstate <= 128 else
|
||||
((4, 8))))))
|
||||
tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
|
||||
with torch.cuda.device(x.device.index):
|
||||
_selective_scan_update_kernel[grid](
|
||||
state, x, dt, dt_bias, A, B, C, D, z, out,
|
||||
batch, nheads, dim, dstate, nheads // ngroups,
|
||||
state.stride(0), state.stride(1), state.stride(2), state.stride(3),
|
||||
x.stride(0), x.stride(1), x.stride(2),
|
||||
dt.stride(0), dt.stride(1), dt.stride(2),
|
||||
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
|
||||
A.stride(0), A.stride(1), A.stride(2),
|
||||
B.stride(0), B.stride(1), B.stride(2),
|
||||
C.stride(0), C.stride(1), C.stride(2),
|
||||
*(D.stride(0), D.stride(1)) if D is not None else 0,
|
||||
z_strides[0], z_strides[1], z_strides[2],
|
||||
out.stride(0), out.stride(1), out.stride(2),
|
||||
dt_softplus,
|
||||
tie_hdim,
|
||||
BLOCK_SIZE_M,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
if not has_heads:
|
||||
out = out.squeeze(1)
|
||||
return out
|
||||
|
||||
|
||||
def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
Return:
|
||||
out: (batch, dim) or (batch, nheads, dim)
|
||||
"""
|
||||
has_heads = state.dim() > 3
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
batch, nheads, dim, dstate = state.shape
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
dt = dt + dt_bias
|
||||
dt = F.softplus(dt) if dt_softplus else dt
|
||||
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
|
||||
B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
|
||||
C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
|
||||
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
|
||||
state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
|
||||
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
|
||||
if D is not None:
|
||||
out += (x * D).to(out.dtype)
|
||||
out = (out if z is None else out * F.silu(z)).to(x.dtype)
|
||||
if not has_heads:
|
||||
out = out.squeeze(1)
|
||||
return out
|
||||
17
Mamba/mamba-main/mamba_ssm/ops/triton/softplus.py
Normal file
17
Mamba/mamba-main/mamba_ssm/ops/triton/softplus.py
Normal file
@ -0,0 +1,17 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from packaging import version
|
||||
|
||||
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
|
||||
|
||||
|
||||
if TRITON3:
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
|
||||
return dt
|
||||
else:
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
||||
return dt
|
||||
262
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_bmm.py
Normal file
262
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_bmm.py
Normal file
@ -0,0 +1,262 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
"""We want triton==2.1.0 or 2.2.0 for this
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def init_to_zero(names):
|
||||
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
|
||||
],
|
||||
key=['chunk_size', 'K', 'IS_CAUSAL'],
|
||||
)
|
||||
@triton.jit
|
||||
def _bmm_chunk_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr, b_ptr, out_ptr, seq_idx_ptr,
|
||||
# Matrix dimensions
|
||||
seqlen, chunk_size, K, ngroups,
|
||||
stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
|
||||
stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
|
||||
stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
|
||||
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
dot_dtype: tl.constexpr,
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_ch = tl.program_id(axis=2)
|
||||
pid_c = pid_ch // ngroups
|
||||
pid_h = pid_ch - pid_c * ngroups
|
||||
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
if IS_CAUSAL:
|
||||
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
|
||||
return
|
||||
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
||||
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
|
||||
b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
|
||||
acc += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
if HAS_SEQ_IDX:
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
||||
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
|
||||
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
|
||||
out = acc.to(out_ptr.dtype.element_ty)
|
||||
|
||||
out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
|
||||
tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
|
||||
],
|
||||
key=['chunk_size', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def _bmm_chunk_bwd_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr, dout_ptr, db_ptr, res_ptr,
|
||||
# Matrix dimensions
|
||||
seqlen, chunk_size, K, ngroups,
|
||||
stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
|
||||
stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
|
||||
stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
|
||||
stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
|
||||
# Meta-parameters
|
||||
dot_dtype: tl.constexpr,
|
||||
HAS_RESIDUAL: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_ch = tl.program_id(axis=2)
|
||||
pid_c = pid_ch // ngroups
|
||||
pid_h = pid_ch - pid_c * ngroups
|
||||
num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
|
||||
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
||||
dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_cs = tl.arange(0, BLOCK_SIZE_CS)
|
||||
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
|
||||
a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
|
||||
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
|
||||
a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
|
||||
acc += tl.dot(dout, a)
|
||||
dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
|
||||
a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
if HAS_RESIDUAL:
|
||||
res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
|
||||
res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
|
||||
res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
|
||||
acc += res
|
||||
db = acc.to(db_ptr.dtype.element_ty)
|
||||
|
||||
db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
|
||||
db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
|
||||
tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
|
||||
|
||||
|
||||
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
|
||||
"""
|
||||
Argument:
|
||||
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
||||
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
||||
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
|
||||
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
|
||||
guaranteed to be correct.
|
||||
Return:
|
||||
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
|
||||
"""
|
||||
# Check constraints.
|
||||
has_groups = a.dim() == 4
|
||||
if not has_groups:
|
||||
batch, seqlen, k = a.shape
|
||||
else:
|
||||
batch, seqlen, ngroups, k = a.shape
|
||||
assert b.shape == a.shape
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
if a.stride(-1) != 1 and a.stride(1) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(-1) != 1 and b.stride(1) != 1:
|
||||
b = b.contiguous()
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
# Allocates output.
|
||||
out_dtype = a.dtype if output_dtype is None else output_dtype
|
||||
out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
|
||||
device=a.device, dtype=out_dtype)
|
||||
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
|
||||
(tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
|
||||
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
|
||||
batch, nchunks if not has_groups else nchunks * ngroups)
|
||||
with torch.cuda.device(a.device.index):
|
||||
_bmm_chunk_fwd_kernel[grid](
|
||||
a, b, out, seq_idx,
|
||||
seqlen, chunk_size, k, ngroups if has_groups else 1,
|
||||
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
|
||||
b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
|
||||
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
|
||||
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
||||
causal,
|
||||
dot_dtype,
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _bmm_chunk_bwd(a, dout, residual=None, out=None):
|
||||
"""
|
||||
Argument:
|
||||
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
||||
dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
|
||||
residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
||||
Return:
|
||||
out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
||||
|
||||
If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
|
||||
zeroed out before calling this function.
|
||||
"""
|
||||
# Check constraints.
|
||||
has_groups = a.dim() == 4
|
||||
if not has_groups:
|
||||
batch, seqlen, k = a.shape
|
||||
else:
|
||||
batch, seqlen, ngroups, k = a.shape
|
||||
nchunks, chunk_size = dout.shape[1], dout.shape[-1]
|
||||
if a.stride(-1) != 1 and a.stride(-2) != 1:
|
||||
a = a.contiguous()
|
||||
if dout.stride(-1) != 1 and dout.stride(-2) != 1:
|
||||
dout = dout.contiguous()
|
||||
if residual is not None:
|
||||
assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
|
||||
if residual.stride(-1) != 1 and residual.stride(1) != 1:
|
||||
residual = residual.contiguous()
|
||||
# Allocates output.
|
||||
if out is not None:
|
||||
assert out.shape == a.shape
|
||||
assert out.stride(-1) == 1 or out.stride(1) == 1
|
||||
else:
|
||||
out = torch.empty_like(a)
|
||||
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
|
||||
(tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
|
||||
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
|
||||
nchunks if not has_groups else nchunks * ngroups)
|
||||
residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
|
||||
residual.stride(-1))
|
||||
if residual is not None else (0, 0, 0, 0))
|
||||
with torch.cuda.device(a.device.index):
|
||||
_bmm_chunk_bwd_kernel[grid](
|
||||
a, dout, out, residual,
|
||||
seqlen, chunk_size, k, ngroups if has_groups else 1,
|
||||
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
|
||||
dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
|
||||
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
|
||||
residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
|
||||
dot_dtype,
|
||||
HAS_RESIDUAL=residual is not None,
|
||||
)
|
||||
return out
|
||||
1829
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_chunk_scan.py
Normal file
1829
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_chunk_scan.py
Normal file
File diff suppressed because it is too large
Load Diff
868
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_chunk_state.py
Normal file
868
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_chunk_state.py
Normal file
@ -0,0 +1,868 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
"""We want triton==2.1.0 or 2.2.0 for this
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from mamba_ssm.ops.triton.softplus import softplus
|
||||
|
||||
|
||||
def init_to_zero(names):
|
||||
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_H': 1}),
|
||||
triton.Config({'BLOCK_SIZE_H': 2}),
|
||||
triton.Config({'BLOCK_SIZE_H': 4}),
|
||||
triton.Config({'BLOCK_SIZE_H': 8}),
|
||||
triton.Config({'BLOCK_SIZE_H': 16}),
|
||||
triton.Config({'BLOCK_SIZE_H': 32}),
|
||||
triton.Config({'BLOCK_SIZE_H': 64}),
|
||||
],
|
||||
key=['chunk_size', 'nheads'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_cumsum_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,
|
||||
# Matrix dimension
|
||||
batch, seqlen, nheads, chunk_size,
|
||||
dt_min, dt_max,
|
||||
# Strides
|
||||
stride_dt_batch, stride_dt_seqlen, stride_dt_head,
|
||||
stride_A_head,
|
||||
stride_dt_bias_head,
|
||||
stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,
|
||||
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
|
||||
):
|
||||
pid_b = tl.program_id(axis=0)
|
||||
pid_c = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
||||
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
|
||||
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
|
||||
|
||||
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
||||
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
||||
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
|
||||
A_ptrs = A_ptr + offs_h * stride_A_head
|
||||
dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
|
||||
dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
||||
dt += dt_bias[:, None]
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
# As of Triton 2.2.0, tl.clamp is not available yet
|
||||
# dt = tl.clamp(dt, dt_min, dt_max)
|
||||
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
||||
dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
|
||||
tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
|
||||
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
||||
dA = dt * A[:, None]
|
||||
dA_cs = tl.cumsum(dA, axis=1)
|
||||
tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
|
||||
],
|
||||
key=['chunk_size', 'nheads'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_cumsum_bwd_kernel(
|
||||
# Pointers to matrices
|
||||
ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,
|
||||
ddt_ptr, dA_ptr, ddt_bias_ptr,
|
||||
# Matrix dimensions
|
||||
batch, seqlen, nheads, chunk_size,
|
||||
dt_min, dt_max,
|
||||
# Strides
|
||||
stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,
|
||||
stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,
|
||||
stride_dt_batch, stride_dt_seqlen, stride_dt_head,
|
||||
stride_A_head,
|
||||
stride_dt_bias_head,
|
||||
stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,
|
||||
stride_dA_head,
|
||||
stride_ddt_bias_head,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
|
||||
):
|
||||
pid_b = tl.program_id(axis=0)
|
||||
pid_c = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
|
||||
ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
||||
ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
|
||||
|
||||
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
||||
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
||||
ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)
|
||||
ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)
|
||||
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
|
||||
ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)
|
||||
A_ptrs = A_ptr + offs_h * stride_A_head
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
|
||||
ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
|
||||
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
||||
ddt = ddA * A[:, None] + ddt_out
|
||||
dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
||||
dt += dt_bias[:, None]
|
||||
if DT_SOFTPLUS:
|
||||
dt_presoftplus = dt
|
||||
dt = softplus(dt)
|
||||
clamp_mask = (dt < dt_min) | (dt > dt_max)
|
||||
# As of Triton 2.2.0, tl.clamp is not available yet
|
||||
# dt = tl.clamp(dt, dt_min, dt_max)
|
||||
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
||||
dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
|
||||
ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)
|
||||
ddt = tl.where(clamp_mask, 0.0, ddt)
|
||||
if DT_SOFTPLUS:
|
||||
ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
|
||||
tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))
|
||||
dA = tl.sum(ddA * dt, axis=1)
|
||||
tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
|
||||
if HAS_DT_BIAS:
|
||||
ddt_bias = tl.sum(ddt, axis=1)
|
||||
tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
|
||||
],
|
||||
key=['hdim', 'dstate', 'chunk_size'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
|
||||
# Matrix dimensions
|
||||
hdim, dstate, chunk_size,
|
||||
batch, seqlen, nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
||||
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
||||
stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
|
||||
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
||||
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
||||
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
||||
# Meta-parameters
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_bc = tl.program_id(axis=1)
|
||||
pid_c = pid_bc // batch
|
||||
pid_b = pid_bc - pid_c * batch
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
|
||||
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)
|
||||
b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
|
||||
if not HAS_SEQ_IDX:
|
||||
scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
|
||||
else:
|
||||
scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
|
||||
],
|
||||
key=['chunk_size', 'hdim', 'dstate'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_bwd_dx_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr,
|
||||
dx_ptr, ddt_ptr, ddA_cumsum_ptr,
|
||||
# Matrix dimensions
|
||||
chunk_size, hdim, dstate,
|
||||
batch, seqlen, nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
||||
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
||||
stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
|
||||
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
||||
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
||||
stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
|
||||
stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
|
||||
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_bc = tl.program_id(axis=1)
|
||||
pid_c = pid_bc // batch
|
||||
pid_b = pid_bc - pid_c * batch
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
||||
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
|
||||
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
||||
b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
|
||||
dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
|
||||
if BLOCK_SIZE_DSTATE <= 128:
|
||||
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
|
||||
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
|
||||
dstates = dstates.to(b_ptr.dtype.element_ty)
|
||||
acc = tl.dot(b, dstates)
|
||||
else:
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
|
||||
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
|
||||
dstates = dstates.to(b_ptr.dtype.element_ty)
|
||||
acc += tl.dot(b, dstates)
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
||||
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
||||
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
||||
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
||||
acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
|
||||
|
||||
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
|
||||
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
||||
ddt = tl.sum(acc * x, axis=1)
|
||||
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
||||
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
||||
ddA_cs = -(ddt * dt_m)
|
||||
ddA_cs_last = -tl.sum(ddA_cs)
|
||||
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
||||
tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
||||
tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
|
||||
|
||||
dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
|
||||
dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
|
||||
dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
|
||||
tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
],
|
||||
key=['chunk_size', 'dstate', 'hdim'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_bwd_db_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
|
||||
db_ptr, ddA_cumsum_ptr,
|
||||
# Matrix dimensions
|
||||
chunk_size, dstate, hdim,
|
||||
batch, seqlen, nheads, nheads_per_program, ngroups,
|
||||
# Strides
|
||||
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
||||
stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
|
||||
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
||||
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
||||
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
||||
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
||||
stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate,
|
||||
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
|
||||
# Meta-parameters
|
||||
HAS_DDA_CS: tl.constexpr,
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_bc = tl.program_id(axis=1)
|
||||
pid_c = pid_bc // batch
|
||||
pid_b = pid_bc - pid_c * batch
|
||||
pid_sg = tl.program_id(axis=2)
|
||||
pid_s = pid_sg // ngroups
|
||||
pid_g = pid_sg - pid_s * ngroups
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
|
||||
db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split
|
||||
dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
|
||||
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
|
||||
if HAS_DDA_CS:
|
||||
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head
|
||||
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim)
|
||||
dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim)
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
||||
if HAS_DDA_CS:
|
||||
b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate)
|
||||
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
if HAS_DDA_CS:
|
||||
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
||||
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
||||
nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
|
||||
for h in range(nheads_iter):
|
||||
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
|
||||
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
|
||||
dstates = dstates.to(x_ptrs.dtype.element_ty)
|
||||
db = tl.dot(x, dstates)
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
||||
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
||||
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
||||
if not HAS_SEQ_IDX:
|
||||
scale = tl.exp(dA_cs_last - dA_cs_m)
|
||||
else:
|
||||
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
||||
db *= (scale * dt_m)[:, None]
|
||||
if HAS_DDA_CS:
|
||||
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
|
||||
ddA_cs = tl.sum(db * b, axis=1)
|
||||
tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
|
||||
acc += db
|
||||
x_ptrs += stride_x_head
|
||||
dstates_ptrs += stride_states_head
|
||||
dt_ptrs += stride_dt_head
|
||||
dA_cumsum_ptr += stride_dA_cs_head
|
||||
dA_cumsum_ptrs += stride_dA_cs_head
|
||||
if HAS_DDA_CS:
|
||||
ddA_cumsum_ptrs += stride_ddA_cs_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
# if HAS_SEQ_IDX:
|
||||
# seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
||||
# seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
||||
# acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
|
||||
db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate)
|
||||
tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
||||
],
|
||||
key=['chunk_size', 'hdim', 'dstate'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_bwd_ddAcs_stable_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
|
||||
ddA_cumsum_ptr,
|
||||
# Matrix dimensions
|
||||
chunk_size, hdim, dstate,
|
||||
batch, seqlen, nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
||||
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
||||
stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
|
||||
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
||||
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
||||
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
||||
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
|
||||
# Meta-parameters
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_bc = tl.program_id(axis=1)
|
||||
pid_c = pid_bc // batch
|
||||
pid_b = pid_bc - pid_c * batch
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
|
||||
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
||||
b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
|
||||
dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
|
||||
if BLOCK_SIZE_DSTATE <= 128:
|
||||
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
|
||||
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
|
||||
dstates = dstates.to(b_ptr.dtype.element_ty)
|
||||
acc = tl.dot(b, dstates)
|
||||
else:
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
|
||||
dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
|
||||
dstates = dstates.to(b_ptr.dtype.element_ty)
|
||||
acc += tl.dot(b, dstates)
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
||||
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
||||
if not HAS_SEQ_IDX:
|
||||
scale = tl.exp(dA_cs_last - dA_cs_m)
|
||||
else:
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
||||
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
||||
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
||||
acc *= scale[:, None]
|
||||
|
||||
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
|
||||
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
||||
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
||||
ddt = tl.sum(acc * x, axis=1)
|
||||
# ddA_cs = -(ddt * dt_m)
|
||||
# Triton 2.2.0 errors if we have the cumsum here, so we just write it out
|
||||
# then call torch.cumsum outside this kernel.
|
||||
# ddA_cs = tl.cumsum(ddt * dt_m)
|
||||
ddA_cs = ddt * dt_m
|
||||
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
||||
# tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
||||
tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
|
||||
|
||||
|
||||
def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
|
||||
batch, seqlen, nheads = dt.shape
|
||||
assert A.shape == (nheads,)
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads,)
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
|
||||
dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
|
||||
grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
|
||||
with torch.cuda.device(dt.device.index):
|
||||
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
|
||||
dt, A, dt_bias, dt_out, dA_cumsum,
|
||||
batch, seqlen, nheads, chunk_size,
|
||||
dt_limit[0], dt_limit[1],
|
||||
dt.stride(0), dt.stride(1), dt.stride(2),
|
||||
A.stride(0),
|
||||
dt_bias.stride(0) if dt_bias is not None else 0,
|
||||
dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),
|
||||
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
||||
dt_softplus,
|
||||
HAS_DT_BIAS=dt_bias is not None,
|
||||
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
||||
)
|
||||
return dA_cumsum, dt_out
|
||||
|
||||
|
||||
def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None):
|
||||
batch, seqlen, nheads = dt.shape
|
||||
_, _, nchunks, chunk_size = ddA.shape
|
||||
assert ddA.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert A.shape == (nheads,)
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads,)
|
||||
ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
|
||||
else:
|
||||
ddt_bias = None
|
||||
if ddt is not None:
|
||||
assert ddt.shape == dt.shape
|
||||
else:
|
||||
ddt = torch.empty_like(dt)
|
||||
dA = torch.empty_like(A, dtype=torch.float32)
|
||||
grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
|
||||
with torch.cuda.device(dt.device.index):
|
||||
_chunk_cumsum_bwd_kernel[grid_chunk_cs](
|
||||
ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,
|
||||
batch, seqlen, nheads, chunk_size,
|
||||
dt_limit[0], dt_limit[1],
|
||||
ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),
|
||||
ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),
|
||||
dt.stride(0), dt.stride(1), dt.stride(2),
|
||||
A.stride(0),
|
||||
dt_bias.stride(0) if dt_bias is not None else 0,
|
||||
ddt.stride(0), ddt.stride(1), ddt.stride(2),
|
||||
dA.stride(0),
|
||||
ddt_bias.stride(0) if ddt_bias is not None else 0,
|
||||
dt_softplus,
|
||||
HAS_DT_BIAS=dt_bias is not None,
|
||||
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
||||
)
|
||||
return ddt, dA, ddt_bias
|
||||
|
||||
|
||||
def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
if states is not None:
|
||||
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
|
||||
else:
|
||||
states_dtype = torch.float32 if states_in_fp32 else B.dtype
|
||||
states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype)
|
||||
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
|
||||
batch * nchunks, nheads)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_fwd_kernel[grid](
|
||||
x, B, states, dt, dA_cumsum, seq_idx,
|
||||
headdim, dstate, chunk_size,
|
||||
batch, seqlen, nheads // ngroups,
|
||||
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
||||
B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
|
||||
states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
|
||||
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
||||
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
||||
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
)
|
||||
return states
|
||||
|
||||
|
||||
def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
||||
if dx is not None:
|
||||
assert dx.shape == x.shape
|
||||
else:
|
||||
dx = torch.empty_like(x)
|
||||
ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
|
||||
ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32)
|
||||
grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
|
||||
batch * nchunks, nheads)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_bwd_dx_kernel[grid_dx](
|
||||
x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum,
|
||||
chunk_size, headdim, dstate,
|
||||
batch, seqlen, nheads // ngroups,
|
||||
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
||||
B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
|
||||
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
|
||||
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
||||
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
||||
dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
|
||||
ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
|
||||
ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
|
||||
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
||||
)
|
||||
return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
|
||||
|
||||
|
||||
def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
dstate = dstates.shape[-1]
|
||||
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
if B is not None:
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
|
||||
# Use torch.empty since the Triton kernel will call init_to_zero
|
||||
ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
|
||||
ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))
|
||||
else:
|
||||
B_strides = (0, 0, 0, 0)
|
||||
ddA_cumsum = None
|
||||
ddA_cumsum_strides = (0, 0, 0, 0)
|
||||
nheads_ngroups_ratio = nheads // ngroups
|
||||
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
||||
nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
|
||||
nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
|
||||
dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32)
|
||||
grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
|
||||
batch * nchunks, nsplits * ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_bwd_db_kernel[grid_db](
|
||||
x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum,
|
||||
chunk_size, dstate, headdim,
|
||||
batch, seqlen, nheads, nheads_per_program, ngroups,
|
||||
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
||||
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
|
||||
*B_strides,
|
||||
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
||||
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
||||
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
||||
dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4),
|
||||
*ddA_cumsum_strides,
|
||||
HAS_DDA_CS=ddA_cumsum is not None,
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
|
||||
)
|
||||
dB = dB.sum(2)
|
||||
if ddA_cumsum is not None:
|
||||
# The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
|
||||
# to the state of the chunk.
|
||||
# torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
||||
# But it's easier to just do the cumsum for all elements, the result will be the same.
|
||||
torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
|
||||
return dB if B is None else (dB, ddA_cumsum)
|
||||
|
||||
|
||||
def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
# Use torch.empty since the Triton kernel will call init_to_zero
|
||||
ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
|
||||
grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
|
||||
batch * nchunks, nheads)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
|
||||
x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum,
|
||||
chunk_size, headdim, dstate,
|
||||
batch, seqlen, nheads // ngroups,
|
||||
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
||||
B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
|
||||
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
|
||||
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
||||
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
||||
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
||||
ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
|
||||
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
||||
)
|
||||
torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
||||
return ddA_cumsum
|
||||
|
||||
|
||||
class ChunkStateFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
assert seqlen <= nchunks * chunk_size
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
|
||||
x = x.contiguous()
|
||||
states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
|
||||
ctx.save_for_backward(B, x, dt, dA_cumsum)
|
||||
return states
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dstates):
|
||||
B, x, dt, dA_cumsum = ctx.saved_tensors
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
||||
if dstates.stride(-1) != 1:
|
||||
dstates = dstates.contiguous()
|
||||
dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
|
||||
dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
|
||||
dB = dB.to(B.dtype)
|
||||
return dB, dx, ddt, ddA_cumsum, None
|
||||
|
||||
|
||||
def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
|
||||
"""
|
||||
Argument:
|
||||
B: (batch, seqlen, ngroups, headdim)
|
||||
x: (batch, seqlen, nheads, headdim)
|
||||
dt: (batch, nheads, nchunks, chunk_size)
|
||||
dA_cumsum: (batch, nheads, nchunks, chunk_size)
|
||||
Return:
|
||||
states: (batch, nchunks, nheads, headdim, dstate)
|
||||
"""
|
||||
return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
|
||||
|
||||
|
||||
def chunk_state_ref(B, x, dt, dA_cumsum):
|
||||
"""
|
||||
Argument:
|
||||
B: (batch, seqlen, ngroups, headdim)
|
||||
x: (batch, seqlen, nheads, headdim)
|
||||
dt: (batch, nheads, nchunks, chunk_size)
|
||||
dA_cumsum: (batch, nheads, nchunks, chunk_size)
|
||||
Return:
|
||||
states: (batch, nchunks, nheads, headdim, dstate)
|
||||
"""
|
||||
# Check constraints.
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
dstate = B.shape[-1]
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
assert seqlen <= nchunks * chunk_size
|
||||
assert x.shape == (batch, seqlen, nheads, headdim)
|
||||
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
||||
ngroups = B.shape[2]
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
|
||||
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
||||
if seqlen < nchunks * chunk_size:
|
||||
x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
|
||||
B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
|
||||
x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
|
||||
B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
|
||||
decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
|
||||
return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)
|
||||
963
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_combined.py
Normal file
963
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_combined.py
Normal file
@ -0,0 +1,963 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
"""We want triton==2.1.0 or 2.2.0 for this
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import math
|
||||
from packaging import version
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
try:
|
||||
from causal_conv1d import causal_conv1d_fn
|
||||
import causal_conv1d_cuda
|
||||
except ImportError:
|
||||
causal_conv1d_fn, causal_conv1d_cuda = None, None
|
||||
|
||||
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
|
||||
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
|
||||
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
|
||||
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
|
||||
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
|
||||
from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd
|
||||
from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
|
||||
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
|
||||
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
|
||||
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
|
||||
from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
|
||||
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
|
||||
from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
|
||||
from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
||||
|
||||
|
||||
def init_to_zero(names):
|
||||
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
|
||||
],
|
||||
key=['chunk_size', 'hdim', 'dstate'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_scan_chunk_state_bwd_dx_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,
|
||||
b_ptr, dstates_ptr,
|
||||
dx_ptr, ddt_ptr, dD_ptr,
|
||||
# Matrix dimensions
|
||||
chunk_size, hdim, dstate,
|
||||
batch, seqlen, nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
|
||||
stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
|
||||
stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
|
||||
stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
|
||||
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
|
||||
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
||||
stride_D_head,
|
||||
stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
|
||||
stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,
|
||||
stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
|
||||
stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
|
||||
stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
|
||||
# Meta-parameters
|
||||
HAS_D: tl.constexpr,
|
||||
D_HAS_HDIM: tl.constexpr,
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
IS_TRITON_22: tl.constexpr,
|
||||
):
|
||||
pid_bc = tl.program_id(axis=1)
|
||||
pid_c = pid_bc // batch
|
||||
pid_b = pid_bc - pid_c * batch
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
|
||||
dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
||||
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
|
||||
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
||||
if not HAS_SEQ_IDX:
|
||||
scale = tl.exp(dA_cs_last - dA_cs_m)
|
||||
else:
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
||||
seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
||||
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
||||
# Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
# However, we're getting error with the Triton compiler 2.1.0 for that code path:
|
||||
# Unexpected mma -> mma layout conversion
|
||||
# Triton 2.2.0 fixes this
|
||||
offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
||||
b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)
|
||||
dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)
|
||||
if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
|
||||
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)
|
||||
dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
|
||||
dstates = dstates.to(b_ptr.dtype.element_ty)
|
||||
acc = tl.dot(b, dstates) * scale[:, None]
|
||||
else:
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)
|
||||
dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
|
||||
dstates = dstates.to(b_ptr.dtype.element_ty)
|
||||
acc += tl.dot(b, dstates)
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
||||
dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
|
||||
acc *= scale[:, None]
|
||||
|
||||
# x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
|
||||
# x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
||||
# dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
||||
# dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
|
||||
# ddt = tl.sum(acc * x, axis=1) * dt_m
|
||||
# ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
||||
# tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
|
||||
dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
K_MAX = chunk_size_limit
|
||||
K_MIN = pid_m * BLOCK_SIZE_M
|
||||
cb_ptrs += K_MIN * stride_cb_csize_k
|
||||
dout_ptrs += K_MIN * stride_dout_seqlen
|
||||
dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
|
||||
for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
|
||||
k = tl.multiple_of(k, BLOCK_SIZE_K)
|
||||
# For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
|
||||
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
|
||||
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
|
||||
cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
|
||||
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
|
||||
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
|
||||
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
|
||||
# This will cause NaN in acc, and hence NaN in dx and ddt.
|
||||
mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
|
||||
cb = tl.where(mask, cb, 0.0)
|
||||
cb = cb.to(dout_ptr.dtype.element_ty)
|
||||
acc += tl.dot(cb, dout)
|
||||
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
|
||||
dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
||||
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
|
||||
dx = acc * dt_m[:, None]
|
||||
dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
|
||||
dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
|
||||
if HAS_D:
|
||||
dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
|
||||
dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
||||
if D_HAS_HDIM:
|
||||
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
|
||||
else:
|
||||
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
|
||||
dx += dout_res * D
|
||||
tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
|
||||
|
||||
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
|
||||
x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
|
||||
if D_HAS_HDIM:
|
||||
dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
|
||||
dD = tl.sum(dout_res * x, axis=0)
|
||||
tl.store(dD_ptrs, dD, mask=offs_n < hdim)
|
||||
else:
|
||||
dD = tl.sum(dout_res * x)
|
||||
tl.store(dD_ptr, dD)
|
||||
ddt = tl.sum(acc * x, axis=1)
|
||||
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
||||
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
||||
|
||||
|
||||
def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
|
||||
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
assert dout.shape == x.shape
|
||||
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
||||
assert D.stride(-1) == 1
|
||||
BLOCK_SIZE_min = 32
|
||||
dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
|
||||
headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
|
||||
else:
|
||||
dD = None
|
||||
dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
|
||||
if D is not None else (0, 0, 0, 0, 0))
|
||||
if dx is None:
|
||||
dx = torch.empty_like(x)
|
||||
else:
|
||||
assert dx.shape == x.shape
|
||||
ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
|
||||
grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
|
||||
batch * nchunks, nheads)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
|
||||
x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,
|
||||
chunk_size, headdim, dstate,
|
||||
batch, seqlen, nheads // ngroups,
|
||||
x.stride(0), x.stride(1), x.stride(2), x.stride(3),
|
||||
CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),
|
||||
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
|
||||
dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
|
||||
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
|
||||
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
||||
D.stride(0) if D is not None else 0,
|
||||
B.stride(0), B.stride(1), B.stride(2), B.stride(3),
|
||||
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
|
||||
dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
|
||||
ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
|
||||
dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
|
||||
D is not None,
|
||||
D.dim() == 2 if D is not None else True,
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
||||
IS_TRITON_22=TRITON_22
|
||||
)
|
||||
if D is not None:
|
||||
BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
|
||||
n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
|
||||
dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
|
||||
if D.dim() == 1:
|
||||
dD = rearrange(dD, "h 1 -> h")
|
||||
return dx, ddt.to(dtype=dt.dtype), dD
|
||||
|
||||
|
||||
def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert x.shape == (batch, seqlen, nheads, headdim)
|
||||
assert dt.shape == (batch, seqlen, nheads)
|
||||
assert A.shape == (nheads,)
|
||||
assert C.shape == B.shape
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
|
||||
x = x.contiguous()
|
||||
if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous
|
||||
z = z.contiguous()
|
||||
if D is not None and D.stride(-1) != 1:
|
||||
D = D.contiguous()
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
||||
# # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
|
||||
# dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
|
||||
# dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
|
||||
# dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
|
||||
dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
|
||||
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
|
||||
# states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
|
||||
# states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
|
||||
# states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
|
||||
states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
|
||||
initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
|
||||
seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)
|
||||
states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]]
|
||||
# states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
|
||||
# states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
|
||||
CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
|
||||
out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
|
||||
return out, out_x, dt, dA_cumsum, states, final_states
|
||||
|
||||
|
||||
def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None,
|
||||
dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert dout.shape == (batch, seqlen, nheads, headdim)
|
||||
assert dt.shape == (batch, seqlen, nheads)
|
||||
assert A.shape == (nheads,)
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
assert out.shape == x.shape
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
if dx is not None:
|
||||
assert dx.shape == x.shape
|
||||
if dB is not None:
|
||||
assert dB.shape == B.shape
|
||||
dB_given = dB
|
||||
else:
|
||||
dB_given = torch.empty_like(B)
|
||||
if dC is not None:
|
||||
assert dC.shape == C.shape
|
||||
dC_given = dC
|
||||
else:
|
||||
dC_given = torch.empty_like(C)
|
||||
if dz is not None:
|
||||
assert z is not None
|
||||
assert dz.shape == z.shape
|
||||
if ddt is not None:
|
||||
assert ddt.shape == dt.shape
|
||||
ddt_given = ddt
|
||||
else:
|
||||
ddt_given = torch.empty_like(dt)
|
||||
# TD: For some reason Triton (2.1.0 and 2.2.0) errors with
|
||||
# "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why.
|
||||
dt_in = dt.clone()
|
||||
dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit)
|
||||
CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
|
||||
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
|
||||
states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
|
||||
initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
|
||||
seq_idx=seq_idx, chunk_size=chunk_size)
|
||||
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
||||
if z is not None:
|
||||
dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output)
|
||||
outz = rest[0] if recompute_output else out
|
||||
else:
|
||||
dz = None
|
||||
outz = out
|
||||
dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype)
|
||||
# dstates has length nchunks, containing the gradient to initial states at index 0 and
|
||||
# gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
|
||||
# Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
|
||||
# will be used in matmul in the next kernels.
|
||||
dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
|
||||
rearrange(states, "... p n -> ... (p n)"),
|
||||
dA_cumsum[:, :, :, -1],
|
||||
rearrange(dstates, "... p n -> ... (p n)"),
|
||||
dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None,
|
||||
seq_idx=seq_idx,
|
||||
has_initial_states=initial_states is not None,
|
||||
dstates_dtype=x.dtype,
|
||||
states_dtype=x.dtype,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
# dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
|
||||
# gradient to the final states at index (nchunks - 1)
|
||||
# states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
|
||||
# The final states is not stored.
|
||||
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
||||
dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
|
||||
dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None
|
||||
dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
|
||||
# dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
|
||||
dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups)
|
||||
# dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
|
||||
dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups)
|
||||
# Computing ddA with the dcb kernel is much slower, so we're not using it for now
|
||||
dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
|
||||
# dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
|
||||
dCB = dCB.to(CB.dtype)
|
||||
_bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
|
||||
_bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
|
||||
# If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
|
||||
# than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
|
||||
if z is None:
|
||||
dD = dD_from_x
|
||||
# Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
|
||||
# ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
|
||||
# However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
|
||||
# be a lot of underflow.
|
||||
|
||||
# This is already done as part of bwd_dC kernel
|
||||
# ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
|
||||
ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
|
||||
ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
|
||||
# This is already done as part of bwd_dB kernel
|
||||
# ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
|
||||
# We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
|
||||
ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
|
||||
ddA += ddA_next + ddA_prev
|
||||
|
||||
ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given)
|
||||
|
||||
# These 2 lines are just to test ddt and dA being computed by old code
|
||||
# _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
|
||||
# ddt_given.copy_(ddt)
|
||||
|
||||
return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states)
|
||||
return return_vals if not recompute_output else (*return_vals, outz)
|
||||
|
||||
|
||||
def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
|
||||
"""
|
||||
Argument:
|
||||
dout: (batch, seqlen, nheads, headdim)
|
||||
x: (batch, seqlen, nheads, headdim)
|
||||
dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
|
||||
A: (nheads) or (dim, dstate)
|
||||
B: (batch, seqlen, ngroups, dstate)
|
||||
C: (batch, seqlen, ngroups, dstate)
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (batch, seqlen, nheads, headdim)
|
||||
Return:
|
||||
out: (batch, seqlen, nheads, headdim)
|
||||
"""
|
||||
import selective_scan
|
||||
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
chunk_size = dt.shape[-1]
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
x = rearrange(x, "b l h p -> b (h p) l")
|
||||
squeeze_dt = dt.dim() == 4
|
||||
if dt.dim() == 4:
|
||||
dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
|
||||
dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
|
||||
squeeze_A = A.dim() == 1
|
||||
if A.dim() == 1:
|
||||
A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
|
||||
else:
|
||||
A = A.to(dtype=torch.float32)
|
||||
B = rearrange(B, "b l g n -> b g n l")
|
||||
C = rearrange(C, "b l g n -> b g n l")
|
||||
if D is not None:
|
||||
if D.dim() == 2:
|
||||
D = rearrange(D, "h p -> (h p)")
|
||||
else:
|
||||
D = repeat(D, "h -> (h p)", p=headdim)
|
||||
if z is not None:
|
||||
z = rearrange(z, "b l h p -> b (h p) l")
|
||||
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if dt.stride(-1) != 1:
|
||||
dt = dt.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
_, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False)
|
||||
if z is not None:
|
||||
out = rest[0]
|
||||
else:
|
||||
out = None
|
||||
|
||||
dout = rearrange(dout, "b l h p -> b (h p) l")
|
||||
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
||||
# backward of selective_scan with the backward of chunk).
|
||||
# Here we just pass in None and dz will be allocated in the C++ code.
|
||||
_, ddt, dA, *rest = selective_scan.bwd(
|
||||
x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False,
|
||||
False # option to recompute out_z, not used here
|
||||
)
|
||||
ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
|
||||
if squeeze_dt:
|
||||
ddt = ddt.float().sum(dim=2)
|
||||
if squeeze_A:
|
||||
dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
|
||||
return ddt, dA
|
||||
|
||||
|
||||
class MambaChunkScanCombinedFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False):
|
||||
ctx.dt_dtype = dt.dtype
|
||||
out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
|
||||
ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx)
|
||||
ctx.dt_softplus = dt_softplus
|
||||
ctx.chunk_size = chunk_size
|
||||
ctx.dt_limit = dt_limit
|
||||
ctx.return_final_states = return_final_states
|
||||
return out if not return_final_states else (out, final_states)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors
|
||||
dfinal_states = args[0] if ctx.return_final_states else None
|
||||
dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)
|
||||
return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None
|
||||
|
||||
|
||||
def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False):
|
||||
"""
|
||||
Argument:
|
||||
x: (batch, seqlen, nheads, headdim)
|
||||
dt: (batch, seqlen, nheads)
|
||||
A: (nheads)
|
||||
B: (batch, seqlen, ngroups, dstate)
|
||||
C: (batch, seqlen, ngroups, dstate)
|
||||
chunk_size: int
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (batch, seqlen, nheads, headdim)
|
||||
dt_bias: (nheads,)
|
||||
initial_states: (batch, nheads, headdim, dstate)
|
||||
seq_idx: (batch, seqlen)
|
||||
dt_softplus: Whether to apply softplus to dt
|
||||
Return:
|
||||
out: (batch, seqlen, nheads, headdim)
|
||||
"""
|
||||
return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
|
||||
|
||||
|
||||
def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
|
||||
"""
|
||||
Argument:
|
||||
x: (batch, seqlen, nheads, headdim)
|
||||
dt: (batch, seqlen, nheads)
|
||||
A: (nheads)
|
||||
B: (batch, seqlen, ngroups, dstate)
|
||||
C: (batch, seqlen, ngroups, dstate)
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (batch, seqlen, nheads, headdim)
|
||||
dt_bias: (nheads,)
|
||||
Return:
|
||||
out: (batch, seqlen, nheads, headdim)
|
||||
"""
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
dstate = B.shape[-1]
|
||||
if seqlen % chunk_size != 0:
|
||||
dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
|
||||
dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
|
||||
dt = dt.float() # We want high precision for this before cumsum
|
||||
if dt_bias is not None:
|
||||
dt = dt + rearrange(dt_bias, "h -> h 1 1")
|
||||
if dt_softplus:
|
||||
dt = F.softplus(dt)
|
||||
dA = dt * rearrange(A, "h -> h 1 1")
|
||||
dA = dt * rearrange(A, "h -> h 1 1")
|
||||
dA_cumsum = torch.cumsum(dA, dim=-1)
|
||||
# 1. Compute the state for each chunk
|
||||
states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
|
||||
# 2. Pass the state to all the chunks by weighted cumsum.
|
||||
states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
|
||||
"... (p n) -> ... p n", n=dstate)
|
||||
# 3. Compute the output for each chunk
|
||||
out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
|
||||
return out
|
||||
|
||||
|
||||
def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
|
||||
"""
|
||||
Argument:
|
||||
x: (batch, seqlen, nheads, headdim)
|
||||
dt: (batch, seqlen, nheads)
|
||||
A: (nheads)
|
||||
B: (batch, seqlen, ngroups, dstate)
|
||||
C: (batch, seqlen, ngroups, dstate)
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (batch, seqlen, nheads, headdim)
|
||||
dt_bias: (nheads,)
|
||||
Return:
|
||||
out: (batch, seqlen, nheads, headdim)
|
||||
"""
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
dstate = B.shape[-1]
|
||||
if seqlen % chunk_size != 0:
|
||||
dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
|
||||
dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
|
||||
dt = dt.float() # We want high precision for this before cumsum
|
||||
if dt_bias is not None:
|
||||
dt = dt + rearrange(dt_bias, "h -> h 1 1")
|
||||
if dt_softplus:
|
||||
dt = F.softplus(dt)
|
||||
dA = dt * rearrange(A, "h -> h 1 1")
|
||||
dA_cumsum = torch.cumsum(dA, dim=-1)
|
||||
# 1. Compute the state for each chunk
|
||||
states = chunk_state_ref(B, x, dt, dA_cumsum)
|
||||
states_dtype = states.dtype
|
||||
if states.dtype not in [torch.float32, torch.float64]:
|
||||
states = states.to(torch.float32)
|
||||
# 2. Pass the state to all the chunks by weighted cumsum.
|
||||
# state_passing_ref is much less numerically stable
|
||||
states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
|
||||
"... (p n) -> ... p n", n=dstate)
|
||||
states = states.to(states_dtype)
|
||||
# 3. Compute the output for each chunk
|
||||
out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
|
||||
return out
|
||||
|
||||
|
||||
def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
|
||||
"""
|
||||
Argument:
|
||||
x: (batch, seqlen, nheads, headdim)
|
||||
dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
|
||||
A: (nheads) or (dim, dstate)
|
||||
B: (batch, seqlen, ngroups, dstate)
|
||||
C: (batch, seqlen, ngroups, dstate)
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (batch, seqlen, nheads, headdim)
|
||||
dt_bias: (nheads,) or (nheads, headdim)
|
||||
Return:
|
||||
out: (batch, seqlen, nheads, headdim)
|
||||
"""
|
||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
x = rearrange(x, "b l h p -> b (h p) l")
|
||||
if dt.dim() == 3:
|
||||
dt = repeat(dt, "b l h -> b l h p", p=headdim)
|
||||
dt = rearrange(dt, "b l h p -> b (h p) l")
|
||||
if A.dim() == 1:
|
||||
A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
|
||||
else:
|
||||
A = A.to(dtype=torch.float32)
|
||||
B = rearrange(B, "b l g n -> b g n l")
|
||||
C = rearrange(C, "b l g n -> b g n l")
|
||||
if D is not None:
|
||||
if D.dim() == 2:
|
||||
D = rearrange(D, "h p -> (h p)")
|
||||
else:
|
||||
D = repeat(D, "h -> (h p)", p=headdim)
|
||||
if z is not None:
|
||||
z = rearrange(z, "b l h p -> b (h p) l")
|
||||
if dt_bias is not None:
|
||||
if dt_bias.dim() == 1:
|
||||
dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
|
||||
dt_bias = rearrange(dt_bias, "h p -> (h p)")
|
||||
if dt_limit != (0.0, float("inf")):
|
||||
if dt_bias is not None:
|
||||
dt = dt + rearrange(dt_bias, "d -> d 1")
|
||||
if dt_softplus:
|
||||
dt = F.softplus(dt)
|
||||
dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
|
||||
dt_bias = None
|
||||
dt_softplus = None
|
||||
out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus)
|
||||
return rearrange(out, "b (h p) l -> b l h p", p=headdim)
|
||||
|
||||
|
||||
def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None,
|
||||
dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")),
|
||||
activation="silu", headdim=None, ngroups=1):
|
||||
"""
|
||||
Argument:
|
||||
xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
|
||||
conv1d_weight: (dim + 2 * ngroups * dstate, width)
|
||||
conv1d_bias: (dim + 2 * ngroups * dstate,)
|
||||
dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
|
||||
A: (nheads)
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (batch, seqlen, dim)
|
||||
dt_bias: (nheads) or (nheads, headdim)
|
||||
headdim: if D is 1D and z is None, headdim must be passed in
|
||||
Return:
|
||||
out: (batch, seqlen, dim)
|
||||
"""
|
||||
batch, seqlen, nheads = dt.shape[:3]
|
||||
assert nheads % ngroups == 0
|
||||
if z is not None:
|
||||
dim = z.shape[-1]
|
||||
assert dim % nheads == 0
|
||||
headdim = dim // nheads
|
||||
else:
|
||||
if D.dim() == 1:
|
||||
assert headdim is not None
|
||||
else:
|
||||
headdim = D.shape[1]
|
||||
dim = nheads * headdim
|
||||
xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
|
||||
"b d s -> b s d")
|
||||
dstate = (xBC.shape[-1] - dim) // ngroups // 2
|
||||
x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
|
||||
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
||||
B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
|
||||
C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
|
||||
z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
|
||||
out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
|
||||
return rearrange(out, "b s h p -> b s (h p)")
|
||||
|
||||
|
||||
class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
|
||||
rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
|
||||
ngroups=1, norm_before_gate=True):
|
||||
assert activation in [None, "silu", "swish"]
|
||||
if D.dim() == 1:
|
||||
assert headdim is not None
|
||||
nheads, = D.shape
|
||||
else:
|
||||
nheads, headdim = D.shape
|
||||
batch, seqlen, _ = zxbcdt.shape
|
||||
dim = nheads * headdim
|
||||
assert nheads % ngroups == 0
|
||||
dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
|
||||
d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
|
||||
assert d_nonssm >= 0
|
||||
assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads)
|
||||
assert dt_bias.shape == (nheads,)
|
||||
assert A.shape == (nheads,)
|
||||
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
|
||||
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
|
||||
xBC_conv = rearrange(
|
||||
causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
|
||||
conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
|
||||
"b d s -> b s d"
|
||||
)
|
||||
x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
|
||||
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
||||
B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
|
||||
C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
|
||||
z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
|
||||
if rmsnorm_weight is None:
|
||||
out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
|
||||
out = rearrange(out, "b s h p -> b s (h p)")
|
||||
rstd = None
|
||||
if d_nonssm > 0:
|
||||
out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
|
||||
else:
|
||||
out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
|
||||
# reshape input data into 2D tensor
|
||||
x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
|
||||
z_rms = rearrange(z, "b s h p -> (b s) (h p)")
|
||||
rmsnorm_weight = rmsnorm_weight.contiguous()
|
||||
if d_nonssm == 0:
|
||||
out = None
|
||||
else:
|
||||
out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device)
|
||||
out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
|
||||
_swiglu_fwd(zx0, out=out01[..., :d_nonssm])
|
||||
out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out,
|
||||
group_size=dim // ngroups,
|
||||
norm_before_gate=norm_before_gate, is_rms_norm=True)
|
||||
if d_nonssm == 0:
|
||||
out = rearrange(out, "(b s) d -> b s d", b=batch)
|
||||
else:
|
||||
out = out01
|
||||
ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None
|
||||
if outproj_weight is not None:
|
||||
if torch.is_autocast_enabled():
|
||||
dtype = torch.get_autocast_gpu_dtype()
|
||||
out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
|
||||
outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None
|
||||
out = F.linear(out, outproj_weight, outproj_bias)
|
||||
else:
|
||||
assert outproj_bias is None
|
||||
ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias,
|
||||
out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias)
|
||||
ctx.dt_limit = dt_limit
|
||||
ctx.return_final_states = return_final_states
|
||||
ctx.activation = activation
|
||||
ctx.rmsnorm_eps = rmsnorm_eps
|
||||
ctx.norm_before_gate = norm_before_gate
|
||||
ctx.chunk_size = chunk_size
|
||||
ctx.headdim = headdim
|
||||
ctx.ngroups = ngroups
|
||||
return out if not return_final_states else (out, final_states)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, dout, *args):
|
||||
zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
|
||||
dfinal_states = args[0] if ctx.return_final_states else None
|
||||
headdim = ctx.headdim
|
||||
nheads = D.shape[0]
|
||||
dim = nheads * headdim
|
||||
assert nheads % ctx.ngroups == 0
|
||||
dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
|
||||
d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
|
||||
assert d_nonssm >= 0
|
||||
recompute_output = outproj_weight is not None
|
||||
if recompute_output:
|
||||
out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype)
|
||||
out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1)
|
||||
zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
|
||||
# Recompute x, B, C
|
||||
xBC_conv = rearrange(
|
||||
causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
|
||||
conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
|
||||
"b d s -> b s d"
|
||||
)
|
||||
x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
|
||||
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
||||
B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
|
||||
C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
|
||||
dzxbcdt = torch.empty_like(zxbcdt)
|
||||
dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
|
||||
dxBC = torch.empty_like(xBC)
|
||||
dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
|
||||
z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
|
||||
dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
|
||||
dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
|
||||
dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
|
||||
if outproj_weight is not None:
|
||||
dout_og = dout
|
||||
dout = F.linear(dout, outproj_weight.t())
|
||||
if d_nonssm > 0:
|
||||
dout0, dout = dout.split([d_nonssm, dim], dim=-1)
|
||||
_swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
|
||||
dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
|
||||
if rmsnorm_weight is None:
|
||||
dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
|
||||
dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd(
|
||||
dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output
|
||||
)
|
||||
out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
|
||||
drmsnorm_weight = None
|
||||
else:
|
||||
batch = dout.shape[0]
|
||||
dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
|
||||
dz = rearrange(dz, "b l d -> (b l) d")
|
||||
x_rms = rearrange(out, "b s h p -> (b s) (h p)")
|
||||
z_rms = rearrange(z, "b s h p -> (b s) (h p)")
|
||||
out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None
|
||||
dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
|
||||
out_for_linear = out_recompute if recompute_output else None
|
||||
dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
|
||||
dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
|
||||
dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC
|
||||
)
|
||||
|
||||
if outproj_weight is not None:
|
||||
doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
|
||||
doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
|
||||
else:
|
||||
doutproj_weight, doutproj_bias = None, None
|
||||
dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
|
||||
dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
||||
rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
|
||||
rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
|
||||
)
|
||||
dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
|
||||
return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
|
||||
|
||||
|
||||
def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
|
||||
"""
|
||||
Argument:
|
||||
zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
|
||||
conv1d_weight: (dim + 2 * ngroups * dstate, width)
|
||||
conv1d_bias: (dim + 2 * ngroups * dstate,)
|
||||
dt_bias: (nheads,)
|
||||
A: (nheads)
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
initial_states: (batch, nheads, headdim, dstate)
|
||||
seq_idx: (batch, seqlen), int32
|
||||
rmsnorm_weight: (dim,)
|
||||
outproj_weight: (out_dim, dim)
|
||||
outproj_bias: (out_dim,)
|
||||
headdim: if D is 1D, headdim must be passed in
|
||||
norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
|
||||
Return:
|
||||
out: (batch, seqlen, dim)
|
||||
"""
|
||||
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
|
||||
|
||||
|
||||
def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
|
||||
"""
|
||||
Argument:
|
||||
zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
|
||||
conv1d_weight: (dim + 2 * ngroups * dstate, width)
|
||||
conv1d_bias: (dim + 2 * ngroups * dstate,)
|
||||
dt_bias: (nheads,)
|
||||
A: (nheads)
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
rmsnorm_weight: (dim,)
|
||||
outproj_weight: (out_dim, dim)
|
||||
outproj_bias: (out_dim,)
|
||||
headdim: if D is 1D, headdim must be passed in
|
||||
norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
|
||||
Return:
|
||||
out: (batch, seqlen, dim)
|
||||
"""
|
||||
if D.dim() == 1:
|
||||
assert headdim is not None
|
||||
nheads, = D.shape
|
||||
else:
|
||||
nheads, headdim = D.shape
|
||||
assert nheads % ngroups == 0
|
||||
batch, seqlen, _ = zxbcdt.shape
|
||||
dim = nheads * headdim
|
||||
dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
|
||||
assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
|
||||
assert dt_bias.shape == (nheads,)
|
||||
assert A.shape == (nheads,)
|
||||
if rmsnorm_weight is not None:
|
||||
assert rmsnorm_weight.shape == (dim,)
|
||||
z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
|
||||
xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
|
||||
"b d s -> b s d")
|
||||
x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
|
||||
x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
|
||||
B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
|
||||
C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
|
||||
z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
|
||||
out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(),
|
||||
z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit)
|
||||
out = rearrange(out, "b s h p -> b s (h p)")
|
||||
if rmsnorm_weight is not None:
|
||||
out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps,
|
||||
norm_before_gate=norm_before_gate)
|
||||
if outproj_weight is not None:
|
||||
out = F.linear(out, outproj_weight, outproj_bias)
|
||||
return out
|
||||
|
||||
348
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_state_passing.py
Normal file
348
Mamba/mamba-main/mamba_ssm/ops/triton/ssd_state_passing.py
Normal file
@ -0,0 +1,348 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
|
||||
"""We want triton==2.1.0 or 2.2.0 for this
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE': 64}),
|
||||
triton.Config({'BLOCK_SIZE': 128}),
|
||||
triton.Config({'BLOCK_SIZE': 256}),
|
||||
triton.Config({'BLOCK_SIZE': 512}),
|
||||
triton.Config({'BLOCK_SIZE': 1024}),
|
||||
triton.Config({'BLOCK_SIZE': 2048}),
|
||||
],
|
||||
key=['dim'],
|
||||
)
|
||||
@triton.jit
|
||||
def _state_passing_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
|
||||
# Matrix dimensions
|
||||
dim, nchunks, seqlen, chunk_size,
|
||||
# Strides
|
||||
stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
|
||||
stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
|
||||
stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
|
||||
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
|
||||
stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
|
||||
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
||||
# Meta-parameters
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
pid_m = tl.program_id(axis=0)
|
||||
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
||||
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
||||
final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += pid_b * stride_seq_idx_batch
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
states_ptrs = states_ptr + offs_m * stride_states_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
|
||||
|
||||
if not HAS_INITSTATES:
|
||||
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
||||
else:
|
||||
initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
out_ptrs += stride_out_chunk
|
||||
seq_idx = 0
|
||||
for c in range(nchunks):
|
||||
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
||||
scale = tl.exp(dA_cs)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
|
||||
scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
|
||||
seq_idx = seq_idx_new
|
||||
states = scale * states + new_states
|
||||
if c < nchunks - 1:
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
else:
|
||||
tl.store(final_states_ptrs, states, mask=offs_m < dim)
|
||||
states_ptrs += stride_states_chunk
|
||||
dA_cs_ptr += stride_dA_cs_chunk
|
||||
out_ptrs += stride_out_chunk
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE': 64}),
|
||||
triton.Config({'BLOCK_SIZE': 128}),
|
||||
triton.Config({'BLOCK_SIZE': 256}),
|
||||
triton.Config({'BLOCK_SIZE': 512}),
|
||||
triton.Config({'BLOCK_SIZE': 1024}),
|
||||
triton.Config({'BLOCK_SIZE': 2048}),
|
||||
],
|
||||
key=['dim'],
|
||||
)
|
||||
@triton.jit
|
||||
def _state_passing_bwd_kernel(
|
||||
# Pointers to matrices
|
||||
dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
|
||||
dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
|
||||
# Matrix dimensions
|
||||
dim, nchunks, seqlen, chunk_size,
|
||||
# Strides
|
||||
stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
|
||||
stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
|
||||
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
|
||||
stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
|
||||
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
||||
stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
|
||||
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
|
||||
stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
|
||||
# Meta-parameters
|
||||
CONVERT_STATES: tl.constexpr,
|
||||
HAS_DFINAL_STATES: tl.constexpr,
|
||||
HAS_DINITSTATES: tl.constexpr,
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
pid_m = tl.program_id(axis=0)
|
||||
dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
|
||||
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
|
||||
ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
|
||||
dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
|
||||
if CONVERT_STATES:
|
||||
states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
|
||||
if HAS_DFINAL_STATES:
|
||||
dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
|
||||
if HAS_DINITSTATES:
|
||||
dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += pid_b * stride_seq_idx_batch
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
dout_ptrs = dout_ptr + offs_m * stride_dout_dim
|
||||
if CONVERT_STATES:
|
||||
states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim
|
||||
|
||||
if HAS_DFINAL_STATES:
|
||||
dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
else:
|
||||
dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
||||
tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
|
||||
dstates_ptrs -= stride_dstates_chunk
|
||||
for c in range(nchunks - 1):
|
||||
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
||||
scale = tl.exp(dA_cs)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
|
||||
scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
|
||||
seq_idx = seq_idx_new
|
||||
out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if CONVERT_STATES:
|
||||
tl.store(states_converted_ptrs, out, mask=offs_m < dim)
|
||||
ddA = tl.sum(out * dstates) * scale
|
||||
tl.store(ddA_cs_ptr, ddA)
|
||||
dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
dstates = scale * dstates + dout
|
||||
tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
|
||||
dout_ptrs -= stride_dout_chunk
|
||||
dstates_ptrs -= stride_dstates_chunk
|
||||
dA_cs_ptr -= stride_dA_cs_chunk
|
||||
ddA_cs_ptr -= stride_ddA_cs_chunk
|
||||
out_ptrs -= stride_out_chunk
|
||||
if CONVERT_STATES:
|
||||
states_converted_ptrs -= stride_out_chunk
|
||||
if CONVERT_STATES:
|
||||
out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
tl.store(states_converted_ptrs, out, mask=offs_m < dim)
|
||||
if not HAS_DINITSTATES:
|
||||
tl.store(ddA_cs_ptr, 0.0)
|
||||
else:
|
||||
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
||||
scale = tl.exp(dA_cs)
|
||||
if HAS_SEQ_IDX:
|
||||
scale = tl.where(seq_idx == 0, scale, 0.0)
|
||||
out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
ddA = tl.sum(out * dstates) * scale
|
||||
tl.store(ddA_cs_ptr, ddA)
|
||||
dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
dstates = scale * dstates + dout
|
||||
tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
|
||||
|
||||
|
||||
def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
|
||||
out_dtype=None):
|
||||
batch, nchunks, nheads, dim = states.shape
|
||||
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (batch, nheads, dim)
|
||||
if seq_idx is not None:
|
||||
assert chunk_size is not None
|
||||
seqlen = seq_idx.shape[-1]
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
out_dtype = states.dtype if out_dtype is None else out_dtype
|
||||
out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
|
||||
final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
|
||||
with torch.cuda.device(states.device.index):
|
||||
_state_passing_fwd_kernel[grid](
|
||||
states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
|
||||
dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
|
||||
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
|
||||
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
|
||||
final_states.stride(0), final_states.stride(1), final_states.stride(2),
|
||||
dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
|
||||
*((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
|
||||
if initial_states is not None else (0, 0, 0)),
|
||||
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
)
|
||||
return out, final_states
|
||||
|
||||
|
||||
def _state_passing_bwd(
|
||||
states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
|
||||
dstates_dtype=None, states_dtype=None, chunk_size=None
|
||||
):
|
||||
"""
|
||||
states contains the initial_states at index 0. The final states are not included in states.
|
||||
"""
|
||||
batch, nchunks, nheads, dim = states.shape
|
||||
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
||||
assert dout.shape == (batch, nchunks, nheads, dim)
|
||||
if seq_idx is not None:
|
||||
assert chunk_size is not None
|
||||
seqlen = seq_idx.shape[-1]
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
|
||||
if states_dtype is not None and states_dtype != states.dtype:
|
||||
states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
|
||||
assert states_converted.stride() == states.stride()
|
||||
else:
|
||||
states_converted = None
|
||||
if has_initial_states:
|
||||
dinitstates = torch.empty_like(dstates[:, 0])
|
||||
else:
|
||||
dinitstates = None
|
||||
if dfinal_states is not None:
|
||||
assert dfinal_states.shape == (batch, nheads, dim)
|
||||
BLOCK_SIZE_min = 64
|
||||
n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
|
||||
ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
|
||||
dtype=torch.float32, device=dA_chunk_cumsum.device)
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
|
||||
with torch.cuda.device(dout.device.index):
|
||||
_state_passing_bwd_kernel[grid](
|
||||
dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
|
||||
dstates, ddA_chunk_cumsum, dinitstates, states_converted,
|
||||
dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
|
||||
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
|
||||
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
|
||||
dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
|
||||
*((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
|
||||
if dfinal_states is not None else (0, 0, 0)),
|
||||
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
||||
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
|
||||
ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
|
||||
*((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
|
||||
if dinitstates is not None else (0, 0, 0)),
|
||||
CONVERT_STATES=states_converted is not None,
|
||||
HAS_DFINAL_STATES=dfinal_states is not None,
|
||||
HAS_DINITSTATES=dinitstates is not None,
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
)
|
||||
BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
|
||||
n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
|
||||
ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
|
||||
if states_dtype is not None and states_dtype == states.dtype:
|
||||
states_converted = states
|
||||
return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
|
||||
|
||||
|
||||
class StatePassingFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
|
||||
batch, nchunks, nheads, dim = states.shape
|
||||
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
||||
if states.stride(-1) != 1:
|
||||
states = states.contiguous()
|
||||
out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
|
||||
ctx.save_for_backward(out, dA_chunk_cumsum)
|
||||
ctx.has_initial_states = initial_states is not None
|
||||
return out, final_states
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, dfinal_states):
|
||||
out, dA_chunk_cumsum = ctx.saved_tensors
|
||||
batch, nchunks, nheads, dim = out.shape
|
||||
assert dout.shape == (batch, nchunks, nheads, dim)
|
||||
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
||||
assert dfinal_states.shape == (batch, nheads, dim)
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
|
||||
out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states
|
||||
)
|
||||
return dstates, ddA_chunk_cumsum, dinitstates
|
||||
|
||||
|
||||
def state_passing(states, dA_chunk_cumsum, initial_states=None):
|
||||
"""
|
||||
Argument:
|
||||
states: (batch, nchunks, nheads, dim)
|
||||
dA_chunk_cumsum: (batch, nheads, nchunks)
|
||||
initial_states: (batch, nheads, dim)
|
||||
Return:
|
||||
out: (batch, nchunks, nheads, dim)
|
||||
final_states: (batch, nheads, dim)
|
||||
"""
|
||||
return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)
|
||||
|
||||
|
||||
def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
|
||||
"""
|
||||
Argument:
|
||||
states: (batch, nchunks, nheads, dim)
|
||||
dA_chunk_cumsum: (batch, nheads, nchunks)
|
||||
initial_states: (batch, nheads, dim)
|
||||
Return:
|
||||
out: (batch, nchunks, nheads, dim)
|
||||
final_states: (batch, nheads, dim)
|
||||
"""
|
||||
if initial_states is None:
|
||||
initial_states = torch.zeros_like(states[:, 0])
|
||||
states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1)
|
||||
dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
|
||||
dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
|
||||
nchunks = dA_chunk_cumsum.shape[-1]
|
||||
# (batch, nheads, nchunks, nchunks)
|
||||
dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
|
||||
# (batch, nheads, nchunks, nchunks)
|
||||
decay_chunk = torch.exp(dt_chunk_segment_sum)
|
||||
causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
|
||||
decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
|
||||
out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
|
||||
return out[:, :-1], out[:, -1]
|
||||
0
Mamba/mamba-main/mamba_ssm/utils/__init__.py
Normal file
0
Mamba/mamba-main/mamba_ssm/utils/__init__.py
Normal file
387
Mamba/mamba-main/mamba_ssm/utils/generation.py
Normal file
387
Mamba/mamba-main/mamba_ssm/utils/generation.py
Normal file
@ -0,0 +1,387 @@
|
||||
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
||||
import gc
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceParams:
|
||||
"""Inference parameters that are passed to the main model in order
|
||||
to efficienly calculate and store the context during inference."""
|
||||
|
||||
max_seqlen: int
|
||||
max_batch_size: int
|
||||
seqlen_offset: int = 0
|
||||
batch_size_offset: int = 0
|
||||
key_value_memory_dict: dict = field(default_factory=dict)
|
||||
lengths_per_sample: Optional[Tensor] = None
|
||||
|
||||
def reset(self, max_seqlen, max_batch_size):
|
||||
self.max_seqlen = max_seqlen
|
||||
self.max_batch_size = max_batch_size
|
||||
self.seqlen_offset = 0
|
||||
if self.lengths_per_sample is not None:
|
||||
self.lengths_per_sample.zero_()
|
||||
|
||||
|
||||
def modify_logits_for_min_p_filtering(logits, min_p):
|
||||
"""Set the logits for none min_p values to -inf. Done in-place."""
|
||||
if min_p <= 0.0 or min_p >= 1.0:
|
||||
return
|
||||
indices_to_remove = logits < min_p
|
||||
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
||||
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
|
||||
def modify_logits_for_top_k_filtering(logits, top_k):
|
||||
"""Set the logits for none top-k values to -inf. Done in-place."""
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
||||
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
||||
def modify_logits_for_top_p_filtering(logits, top_p):
|
||||
"""Set the logits for none top-p values to -inf. Done in-place."""
|
||||
if top_p <= 0.0 or top_p >= 1.0:
|
||||
return
|
||||
# First sort and calculate cumulative sum of probabilities.
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
logits.masked_fill_(indices_to_remove, float("-inf"))
|
||||
|
||||
|
||||
def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
|
||||
"""Apply repetition penalty. See https://arxiv.org/abs/1909.05858
|
||||
logits: (batch_size, vocab_size)
|
||||
prev_output_tokens: (batch_size, seq_len)
|
||||
"""
|
||||
if repetition_penalty == 1.0:
|
||||
return logits
|
||||
score = torch.gather(logits, 1, prev_output_tokens)
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
|
||||
logits.scatter_(1, prev_output_tokens, score)
|
||||
return logits
|
||||
|
||||
|
||||
def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
|
||||
"""Sample from top-k logits.
|
||||
Arguments:
|
||||
logits: Tensor of shape (batch_size, vocab_size)
|
||||
"""
|
||||
if top_k == 1: # Short-circuit for greedy decoding
|
||||
return logits.argmax(dim=-1)
|
||||
else:
|
||||
if top_p > 0.0:
|
||||
assert top_p <= 1.0, "top-p should be in (0, 1]."
|
||||
if top_k > 0:
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
||||
if temperature != 1.0:
|
||||
logits_top /= temperature
|
||||
modify_logits_for_top_p_filtering(logits_top, top_p)
|
||||
return indices[
|
||||
torch.arange(indices.shape[0], device=indices.device),
|
||||
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
|
||||
]
|
||||
else:
|
||||
if min_p > 0.0:
|
||||
logits_top = logits.clone()
|
||||
max_prob = logits_top[..., 0].item()
|
||||
min_prob = max_prob * min_p
|
||||
modify_logits_for_min_p_filtering(logits_top, min_prob)
|
||||
if temperature != 1.0:
|
||||
logits_top /= temperature
|
||||
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
|
||||
# Clone so that when we modify for top_p we don't change the original logits
|
||||
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
|
||||
modify_logits_for_top_p_filtering(logits_top, top_p)
|
||||
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(
|
||||
input_ids,
|
||||
model,
|
||||
max_length,
|
||||
top_k=1,
|
||||
top_p=0.0,
|
||||
min_p=0.0,
|
||||
temperature=1.0,
|
||||
repetition_penalty=1.0,
|
||||
eos_token_id=None,
|
||||
teacher_outputs=None,
|
||||
vocab_size=None,
|
||||
cg=False,
|
||||
enable_timing=False,
|
||||
streamer: Optional[TextStreamer] = None
|
||||
):
|
||||
"""Decoding, either greedy or with top-k or top-p sampling.
|
||||
If top-k = 0, don't limit the number of candidates (pure sampling).
|
||||
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
||||
then top-p.
|
||||
We assume that all sequences in the same batch have the same length.
|
||||
|
||||
Arguments:
|
||||
input_ids: (batch, seq_len)
|
||||
max_length: int
|
||||
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
||||
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
||||
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
|
||||
sequences: (batch, max_length)
|
||||
scores: tuples of (batch, vocab_size)
|
||||
"""
|
||||
if streamer is not None:
|
||||
streamer.put(input_ids.cpu())
|
||||
|
||||
batch_size, seqlen_og = input_ids.shape
|
||||
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
|
||||
if cg:
|
||||
if not hasattr(model, "_decoding_cache"):
|
||||
model._decoding_cache = None
|
||||
model._decoding_cache = update_graph_cache(
|
||||
model,
|
||||
model._decoding_cache,
|
||||
batch_size,
|
||||
seqlen_og,
|
||||
max_length,
|
||||
)
|
||||
inference_params = model._decoding_cache.inference_params
|
||||
inference_params.reset(max_length, batch_size)
|
||||
else:
|
||||
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
|
||||
|
||||
def get_logits(input_ids, inference_params):
|
||||
decoding = inference_params.seqlen_offset > 0
|
||||
if decoding:
|
||||
position_ids = torch.full(
|
||||
(batch_size, 1),
|
||||
inference_params.seqlen_offset,
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
)
|
||||
else:
|
||||
position_ids = None
|
||||
if not cg or not decoding:
|
||||
logits = model(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
num_last_tokens=1,
|
||||
).logits.squeeze(dim=1)
|
||||
else:
|
||||
logits = model._decoding_cache.run(
|
||||
input_ids, position_ids, inference_params.seqlen_offset
|
||||
).squeeze(dim=1)
|
||||
return logits[..., :vocab_size] if vocab_size is not None else logits
|
||||
|
||||
def sample_tokens(logits, inference_params):
|
||||
if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
|
||||
token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
|
||||
else:
|
||||
token = teacher_outputs[:, inference_params.seqlen_offset]
|
||||
# return rearrange(token, "b -> b 1")
|
||||
return token.unsqueeze(1)
|
||||
|
||||
def should_stop(current_token, inference_params):
|
||||
if inference_params.seqlen_offset == 0:
|
||||
return False
|
||||
if eos_token_id is not None and (current_token == eos_token_id).all():
|
||||
return True
|
||||
if inference_params.seqlen_offset >= max_length - 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
start = torch.cuda.Event(enable_timing=enable_timing)
|
||||
end = torch.cuda.Event(enable_timing=enable_timing)
|
||||
|
||||
if enable_timing:
|
||||
start.record()
|
||||
scores, sequences = [], [input_ids]
|
||||
sequences_cat = input_ids
|
||||
while not should_stop(sequences[-1], inference_params):
|
||||
scores.append(get_logits(sequences[-1], inference_params))
|
||||
inference_params.seqlen_offset += sequences[-1].shape[1]
|
||||
if repetition_penalty == 1.0:
|
||||
sampled_tokens = sample_tokens(scores[-1], inference_params)
|
||||
else:
|
||||
logits = modify_logit_for_repetition_penalty(
|
||||
scores[-1].clone(), sequences_cat, repetition_penalty
|
||||
)
|
||||
sampled_tokens = sample_tokens(logits, inference_params)
|
||||
sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
|
||||
sequences.append(sampled_tokens)
|
||||
if streamer is not None:
|
||||
streamer.put(sampled_tokens.cpu())
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
if enable_timing:
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
|
||||
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
||||
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
||||
|
||||
|
||||
class GenerationMixin:
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids,
|
||||
max_length,
|
||||
top_k=1,
|
||||
top_p=0.0,
|
||||
min_p=0.0,
|
||||
temperature=1.0,
|
||||
return_dict_in_generate=False,
|
||||
output_scores=False,
|
||||
**kwargs,
|
||||
):
|
||||
output = decode(
|
||||
input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs
|
||||
)
|
||||
if not output_scores:
|
||||
output.scores = None
|
||||
return output if return_dict_in_generate else output.sequences
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodingCGCache:
|
||||
max_batch_size: int = 0
|
||||
max_seqlen: int = 0
|
||||
device = None
|
||||
dtype = None
|
||||
callables: dict = field(default_factory=dict)
|
||||
mempool = None
|
||||
inference_params: Optional[InferenceParams] = None
|
||||
run: Optional[Callable] = None
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def update_graph_cache(
|
||||
model,
|
||||
cache,
|
||||
batch_size,
|
||||
seqlen_og,
|
||||
max_seqlen,
|
||||
decoding_seqlens=(1,),
|
||||
dtype=None,
|
||||
n_warmups=2,
|
||||
):
|
||||
if cache is None:
|
||||
cache = DecodingCGCache()
|
||||
param_example = next(iter(model.parameters()))
|
||||
device = param_example.device
|
||||
if dtype is None:
|
||||
dtype = param_example.dtype
|
||||
if (
|
||||
(device, dtype) != (cache.device, cache.dtype)
|
||||
or batch_size > cache.max_batch_size
|
||||
or max_seqlen > cache.max_seqlen
|
||||
): # Invalidate the cache
|
||||
cache.callables = {}
|
||||
cache.mempool = None
|
||||
cache.inference_params = None
|
||||
gc.collect()
|
||||
cache.device, cache.dtype = device, dtype
|
||||
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
||||
assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
|
||||
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
||||
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
|
||||
cache.inference_params = InferenceParams(
|
||||
max_seqlen=max_seqlen,
|
||||
max_batch_size=batch_size,
|
||||
seqlen_offset=seqlen_og,
|
||||
key_value_memory_dict=inf_cache,
|
||||
lengths_per_sample=lengths_per_sample,
|
||||
)
|
||||
cache.mempool = torch.cuda.graphs.graph_pool_handle()
|
||||
for decoding_seqlen in decoding_seqlens:
|
||||
if (batch_size, decoding_seqlen) not in cache.callables:
|
||||
cache.callables[batch_size, decoding_seqlen] = capture_graph(
|
||||
model,
|
||||
cache.inference_params,
|
||||
batch_size,
|
||||
max_seqlen,
|
||||
decoding_seqlen=decoding_seqlen,
|
||||
mempool=cache.mempool,
|
||||
n_warmups=n_warmups,
|
||||
)
|
||||
|
||||
def dispatch(input_ids, position_ids, seqlen):
|
||||
batch_size, decoding_seqlen = input_ids.shape[:2]
|
||||
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
|
||||
|
||||
cache.run = dispatch
|
||||
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
|
||||
return cache
|
||||
|
||||
|
||||
def capture_graph(
|
||||
model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
|
||||
):
|
||||
device = next(iter(model.parameters())).device
|
||||
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
||||
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
||||
seqlen_offset_og = inference_params.seqlen_offset
|
||||
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
|
||||
inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
|
||||
|
||||
# Warmup before capture
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(n_warmups):
|
||||
logits = model(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
num_last_tokens=decoding_seqlen,
|
||||
).logits
|
||||
s.synchronize()
|
||||
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
||||
# which requires that graph launch and non-captured launch to not overlap (I think,
|
||||
# that's how I interpret the documentation). I'm not sure if this is required.
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.barrier()
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
# Captures the graph
|
||||
# To allow capture, automatically sets a side stream as the current stream in the context
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, pool=mempool):
|
||||
logits = model(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
num_last_tokens=decoding_seqlen,
|
||||
).logits
|
||||
|
||||
def run(new_input_ids, new_position_ids, seqlen):
|
||||
inference_params.lengths_per_sample[:] = seqlen
|
||||
input_ids.copy_(new_input_ids)
|
||||
position_ids.copy_(new_position_ids)
|
||||
graph.replay()
|
||||
return logits.clone()
|
||||
|
||||
inference_params.seqlen_offset = seqlen_offset_og
|
||||
return run
|
||||
23
Mamba/mamba-main/mamba_ssm/utils/hf.py
Normal file
23
Mamba/mamba-main/mamba_ssm/utils/hf.py
Normal file
@ -0,0 +1,23 @@
|
||||
import json
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
|
||||
from transformers.utils.hub import cached_file
|
||||
|
||||
|
||||
def load_config_hf(model_name):
|
||||
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
|
||||
return json.load(open(resolved_archive_file))
|
||||
|
||||
|
||||
def load_state_dict_hf(model_name, device=None, dtype=None):
|
||||
# If not fp32, then we don't want to load directly to the GPU
|
||||
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
|
||||
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
|
||||
return torch.load(resolved_archive_file, map_location=mapped_device)
|
||||
# Convert dtype before moving to GPU to save memory
|
||||
if dtype is not None:
|
||||
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
|
||||
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
|
||||
return state_dict
|
||||
56
Mamba/mamba-main/rocm_patch/rocm6_0.patch
Normal file
56
Mamba/mamba-main/rocm_patch/rocm6_0.patch
Normal file
@ -0,0 +1,56 @@
|
||||
--- /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h 2023-12-12 20:11:48.000000000 +0000
|
||||
+++ rocm_update_files/amd_hip_bf16.h 2024-05-20 17:40:26.983349079 +0000
|
||||
@@ -137,7 +137,7 @@
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT16_CONV
|
||||
* \brief Converts float to bfloat16
|
||||
*/
|
||||
-__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) {
|
||||
+__HOST_DEVICE__ static inline __hip_bfloat16 __float2bfloat16(float f) {
|
||||
__hip_bfloat16 ret;
|
||||
union {
|
||||
float fp32;
|
||||
@@ -181,7 +181,7 @@
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
||||
* \brief Converts and moves bfloat162 to float2
|
||||
*/
|
||||
-__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
|
||||
+__HOST_DEVICE__ static inline float2 __bfloat1622float2(const __hip_bfloat162 a) {
|
||||
return float2{__bfloat162float(a.x), __bfloat162float(a.y)};
|
||||
}
|
||||
|
||||
@@ -209,7 +209,7 @@
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
||||
* \brief Convert double to __hip_bfloat16
|
||||
*/
|
||||
-__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) {
|
||||
+__HOST_DEVICE__ static inline __hip_bfloat16 __double2bfloat16(const double a) {
|
||||
return __float2bfloat16((float)a);
|
||||
}
|
||||
|
||||
@@ -217,7 +217,7 @@
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
||||
* \brief Convert float2 to __hip_bfloat162
|
||||
*/
|
||||
-__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
|
||||
+__HOST_DEVICE__ static inline __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
|
||||
return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)};
|
||||
}
|
||||
|
||||
@@ -247,7 +247,7 @@
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
||||
* \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result
|
||||
*/
|
||||
-__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
|
||||
+__HOST_DEVICE__ static inline float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
||||
@@ -275,7 +275,7 @@
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
||||
* \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result
|
||||
*/
|
||||
-__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
|
||||
+__HOST_DEVICE__ static inline float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
|
||||
|
||||
/**
|
||||
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
||||
398
Mamba/mamba-main/setup.py
Normal file
398
Mamba/mamba-main/setup.py
Normal file
@ -0,0 +1,398 @@
|
||||
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
||||
import sys
|
||||
import warnings
|
||||
import os
|
||||
import re
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from packaging.version import parse, Version
|
||||
import platform
|
||||
import shutil
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
import subprocess
|
||||
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||
|
||||
import torch
|
||||
from torch.utils.cpp_extension import (
|
||||
BuildExtension,
|
||||
CppExtension,
|
||||
CUDAExtension,
|
||||
CUDA_HOME,
|
||||
HIP_HOME
|
||||
)
|
||||
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
|
||||
# ninja build does not work unless include_dirs are abs path
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
PACKAGE_NAME = "mamba_ssm"
|
||||
|
||||
BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}"
|
||||
|
||||
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
|
||||
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
|
||||
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
|
||||
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
|
||||
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
|
||||
FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
|
||||
|
||||
|
||||
def get_platform():
|
||||
"""
|
||||
Returns the platform name as used in wheel filenames.
|
||||
"""
|
||||
if sys.platform.startswith("linux"):
|
||||
return "linux_x86_64"
|
||||
elif sys.platform == "darwin":
|
||||
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
|
||||
return f"macosx_{mac_version}_x86_64"
|
||||
elif sys.platform == "win32":
|
||||
return "win_amd64"
|
||||
else:
|
||||
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
||||
)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
bare_metal_ver = parse(output[release_idx].split(",")[0])
|
||||
|
||||
return raw_output, bare_metal_ver
|
||||
|
||||
|
||||
def get_hip_version(rocm_dir):
|
||||
|
||||
hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
|
||||
try:
|
||||
raw_output = subprocess.check_output(
|
||||
[hipcc_bin, "--version"], universal_newlines=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
|
||||
)
|
||||
return None, None
|
||||
|
||||
for line in raw_output.split("\n"):
|
||||
if "HIP version" in line:
|
||||
rocm_version = parse(line.split()[-1].replace("-", "+")) # local version is not parsed correctly
|
||||
return line, rocm_version
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def get_torch_hip_version():
|
||||
|
||||
if torch.version.hip:
|
||||
return parse(torch.version.hip.split()[-1].replace("-", "+"))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def check_if_hip_home_none(global_option: str) -> None:
|
||||
|
||||
if HIP_HOME is not None:
|
||||
return
|
||||
# warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
|
||||
# in that case.
|
||||
warnings.warn(
|
||||
f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
|
||||
)
|
||||
|
||||
|
||||
def check_if_cuda_home_none(global_option: str) -> None:
|
||||
if CUDA_HOME is not None:
|
||||
return
|
||||
# warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
|
||||
# in that case.
|
||||
warnings.warn(
|
||||
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
|
||||
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
|
||||
"only images whose names contain 'devel' will provide nvcc."
|
||||
)
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
return nvcc_extra_args + ["--threads", "4"]
|
||||
|
||||
|
||||
cmdclass = {}
|
||||
ext_modules = []
|
||||
|
||||
|
||||
HIP_BUILD = bool(torch.version.hip)
|
||||
|
||||
if not SKIP_CUDA_BUILD:
|
||||
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
||||
TORCH_MAJOR = int(torch.__version__.split(".")[0])
|
||||
TORCH_MINOR = int(torch.__version__.split(".")[1])
|
||||
|
||||
cc_flag = []
|
||||
|
||||
if HIP_BUILD:
|
||||
check_if_hip_home_none(PACKAGE_NAME)
|
||||
|
||||
rocm_home = os.getenv("ROCM_PATH")
|
||||
_, hip_version = get_hip_version(rocm_home)
|
||||
|
||||
if HIP_HOME is not None:
|
||||
if hip_version < Version("6.0"):
|
||||
raise RuntimeError(
|
||||
f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. "
|
||||
"Note: make sure HIP has a supported version by running hipcc --version."
|
||||
)
|
||||
if hip_version == Version("6.0"):
|
||||
warnings.warn(
|
||||
f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. "
|
||||
"Refer to the README.md for detailed instructions.",
|
||||
UserWarning
|
||||
)
|
||||
|
||||
cc_flag.append("-DBUILD_PYTHON_PACKAGE")
|
||||
|
||||
else:
|
||||
check_if_cuda_home_none(PACKAGE_NAME)
|
||||
# Check, if CUDA11 is installed for compute capability 8.0
|
||||
|
||||
if CUDA_HOME is not None:
|
||||
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if bare_metal_version < Version("11.6"):
|
||||
raise RuntimeError(
|
||||
f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
|
||||
"Note: make sure nvcc has a supported version by running nvcc -V."
|
||||
)
|
||||
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_53,code=sm_53")
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_62,code=sm_62")
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_70,code=sm_70")
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_72,code=sm_72")
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_80,code=sm_80")
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_87,code=sm_87")
|
||||
|
||||
if bare_metal_version >= Version("11.8"):
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_90,code=sm_90")
|
||||
|
||||
|
||||
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
|
||||
# torch._C._GLIBCXX_USE_CXX11_ABI
|
||||
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
|
||||
if FORCE_CXX11_ABI:
|
||||
torch._C._GLIBCXX_USE_CXX11_ABI = True
|
||||
|
||||
if HIP_BUILD:
|
||||
|
||||
try:
|
||||
# set warp size based on gcn architecure
|
||||
gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
|
||||
if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
|
||||
# radeon
|
||||
warp_size = 32
|
||||
else:
|
||||
# instinct
|
||||
warp_size = 64
|
||||
except AttributeError as e:
|
||||
# fall back to crude method to set warp size
|
||||
device_name = torch.cuda.get_device_properties(0).name
|
||||
if 'instinct' in device_name.lower():
|
||||
warp_size = 64
|
||||
else:
|
||||
warp_size = 32
|
||||
|
||||
extra_compile_args = {
|
||||
"cxx": ["-O3", "-std=c++17"],
|
||||
"nvcc": [
|
||||
"-O3",
|
||||
"-std=c++17",
|
||||
f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-DCK_FMHA_FWD_FAST_EXP2=1",
|
||||
"-fgpu-flush-denormals-to-zero",
|
||||
f"-DROCM_WARP_SIZE={warp_size}"
|
||||
]
|
||||
+ cc_flag,
|
||||
}
|
||||
else:
|
||||
extra_compile_args = {
|
||||
"cxx": ["-O3", "-std=c++17"],
|
||||
"nvcc": append_nvcc_threads(
|
||||
[
|
||||
"-O3",
|
||||
"-std=c++17",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
|
||||
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
||||
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
|
||||
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
"--use_fast_math",
|
||||
"--ptxas-options=-v",
|
||||
"-lineinfo",
|
||||
]
|
||||
+ cc_flag
|
||||
),
|
||||
}
|
||||
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
name="selective_scan_cuda",
|
||||
sources=[
|
||||
"csrc/selective_scan/selective_scan.cpp",
|
||||
"csrc/selective_scan/selective_scan_fwd_fp32.cu",
|
||||
"csrc/selective_scan/selective_scan_fwd_fp16.cu",
|
||||
"csrc/selective_scan/selective_scan_fwd_bf16.cu",
|
||||
"csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
|
||||
"csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
|
||||
"csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
|
||||
"csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
|
||||
"csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
|
||||
"csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
|
||||
],
|
||||
extra_compile_args=extra_compile_args,
|
||||
include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_package_version():
|
||||
with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
|
||||
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
|
||||
public_version = ast.literal_eval(version_match.group(1))
|
||||
local_version = os.environ.get("MAMBA_LOCAL_VERSION")
|
||||
if local_version:
|
||||
return f"{public_version}+{local_version}"
|
||||
else:
|
||||
return str(public_version)
|
||||
|
||||
|
||||
def get_wheel_url():
|
||||
# Determine the version numbers that will be used to determine the correct wheel
|
||||
torch_version_raw = parse(torch.__version__)
|
||||
|
||||
if HIP_BUILD:
|
||||
# We're using the HIP version used to build torch, not the one currently installed
|
||||
torch_hip_version = get_torch_hip_version()
|
||||
hip_ver = f"{torch_hip_version.major}{torch_hip_version.minor}"
|
||||
else:
|
||||
# We're using the CUDA version used to build torch, not the one currently installed
|
||||
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
torch_cuda_version = parse(torch.version.cuda)
|
||||
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
|
||||
# to save CI time. Minor versions should be compatible.
|
||||
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
|
||||
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
|
||||
|
||||
gpu_compute_version = hip_ver if HIP_BUILD else cuda_version
|
||||
cuda_or_hip = "hip" if HIP_BUILD else "cu"
|
||||
|
||||
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
|
||||
platform_name = get_platform()
|
||||
mamba_ssm_version = get_package_version()
|
||||
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
|
||||
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
|
||||
|
||||
# Determine wheel URL based on CUDA version, torch version, python version and OS
|
||||
wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
|
||||
wheel_url = BASE_WHEEL_URL.format(
|
||||
tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename
|
||||
)
|
||||
return wheel_url, wheel_filename
|
||||
|
||||
|
||||
class CachedWheelsCommand(_bdist_wheel):
|
||||
"""
|
||||
The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
|
||||
find an existing wheel (which is currently the case for all installs). We use
|
||||
the environment parameters to detect whether there is already a pre-built version of a compatible
|
||||
wheel available and short-circuits the standard full build pipeline.
|
||||
"""
|
||||
|
||||
def run(self):
|
||||
if FORCE_BUILD:
|
||||
return super().run()
|
||||
|
||||
wheel_url, wheel_filename = get_wheel_url()
|
||||
print("Guessing wheel URL: ", wheel_url)
|
||||
try:
|
||||
urllib.request.urlretrieve(wheel_url, wheel_filename)
|
||||
|
||||
# Make the archive
|
||||
# Lifted from the root wheel processing command
|
||||
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
|
||||
if not os.path.exists(self.dist_dir):
|
||||
os.makedirs(self.dist_dir)
|
||||
|
||||
impl_tag, abi_tag, plat_tag = self.get_tag()
|
||||
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
|
||||
|
||||
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
|
||||
print("Raw wheel path", wheel_path)
|
||||
shutil.move(wheel_filename, wheel_path)
|
||||
except urllib.error.HTTPError:
|
||||
print("Precompiled wheel not found. Building from source...")
|
||||
# If the wheel could not be downloaded, build from source
|
||||
super().run()
|
||||
|
||||
setup(
|
||||
name=PACKAGE_NAME,
|
||||
version=get_package_version(),
|
||||
packages=find_packages(
|
||||
exclude=(
|
||||
"build",
|
||||
"csrc",
|
||||
"include",
|
||||
"tests",
|
||||
"dist",
|
||||
"docs",
|
||||
"benchmarks",
|
||||
"mamba_ssm.egg-info",
|
||||
)
|
||||
),
|
||||
author="Tri Dao, Albert Gu",
|
||||
author_email="tri@tridao.me, agu@cs.cmu.edu",
|
||||
description="Mamba state-space model",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/state-spaces/mamba",
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: BSD License",
|
||||
"Operating System :: Unix",
|
||||
],
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
|
||||
if ext_modules
|
||||
else {
|
||||
"bdist_wheel": CachedWheelsCommand,
|
||||
},
|
||||
python_requires=">=3.7",
|
||||
install_requires=[
|
||||
"torch",
|
||||
"packaging",
|
||||
"ninja",
|
||||
"einops",
|
||||
"triton",
|
||||
"transformers",
|
||||
# "causal_conv1d>=1.2.0",
|
||||
],
|
||||
)
|
||||
247
Mamba/mamba-main/tests/ops/test_selective_scan.py
Normal file
247
Mamba/mamba-main/tests/ops/test_selective_scan.py
Normal file
@ -0,0 +1,247 @@
|
||||
# Copyright (C) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytest
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
|
||||
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref
|
||||
|
||||
|
||||
# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
|
||||
@pytest.mark.parametrize('wtype', [torch.float32])
|
||||
# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize('itype', [torch.float32])
|
||||
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
|
||||
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
|
||||
# @pytest.mark.parametrize('seqlen', [128])
|
||||
# @pytest.mark.parametrize("return_last_state", [False, True])
|
||||
@pytest.mark.parametrize("return_last_state", [True])
|
||||
# @pytest.mark.parametrize('has_delta_bias', [False, True])
|
||||
@pytest.mark.parametrize('has_delta_bias', [True])
|
||||
# @pytest.mark.parametrize('delta_softplus', [False, True])
|
||||
@pytest.mark.parametrize('delta_softplus', [True])
|
||||
# @pytest.mark.parametrize('has_z', [False, True])
|
||||
@pytest.mark.parametrize('has_z', [True])
|
||||
# @pytest.mark.parametrize('has_D', [False, True])
|
||||
@pytest.mark.parametrize('has_D', [True])
|
||||
@pytest.mark.parametrize("varBC_groups", [1, 2])
|
||||
# @pytest.mark.parametrize("varBC_groups", [1])
|
||||
# @pytest.mark.parametrize("is_variable_C", [False, True])
|
||||
@pytest.mark.parametrize("is_variable_C", [True])
|
||||
# @pytest.mark.parametrize("is_variable_B", [False, True])
|
||||
@pytest.mark.parametrize("is_variable_B", [True])
|
||||
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
|
||||
delta_softplus, return_last_state, seqlen, itype, wtype):
|
||||
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
|
||||
pytest.skip() # This config is not applicable
|
||||
device = 'cuda'
|
||||
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 3e-2, 5e-2
|
||||
rtolw, atolw = (1e-3, 1e-3)
|
||||
if has_z: # If we have z, the errors on the weights seem higher
|
||||
rtolw = max(rtolw, rtol)
|
||||
atolw = max(atolw, atol)
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 2
|
||||
dim = 4
|
||||
dstate = 8
|
||||
is_complex = wtype == torch.complex64
|
||||
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
|
||||
if not is_variable_B:
|
||||
B_shape = (dim, dstate)
|
||||
elif varBC_groups == 1:
|
||||
B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
|
||||
else:
|
||||
B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
|
||||
B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,
|
||||
requires_grad=True)
|
||||
if not is_variable_C:
|
||||
C_shape = (dim, dstate)
|
||||
elif varBC_groups == 1:
|
||||
C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
|
||||
else:
|
||||
C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
|
||||
C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,
|
||||
requires_grad=True)
|
||||
if has_D:
|
||||
D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
||||
else:
|
||||
D = None
|
||||
if has_z:
|
||||
z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
|
||||
else:
|
||||
z = None
|
||||
if has_delta_bias:
|
||||
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
|
||||
else:
|
||||
delta_bias = None
|
||||
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
|
||||
delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_()
|
||||
A_ref = A.detach().clone().requires_grad_()
|
||||
B_ref = B.detach().clone().requires_grad_()
|
||||
C_ref = C.detach().clone().requires_grad_()
|
||||
D_ref = D.detach().clone().requires_grad_() if D is not None else None
|
||||
z_ref = z.detach().clone().requires_grad_() if z is not None else None
|
||||
u_ref = u.detach().clone().requires_grad_()
|
||||
delta_ref = delta.detach().clone().requires_grad_()
|
||||
delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
|
||||
out, *rest = selective_scan_fn(
|
||||
u, delta, A, B, C, D, z=z,
|
||||
delta_bias=delta_bias, delta_softplus=delta_softplus,
|
||||
return_last_state=return_last_state
|
||||
)
|
||||
if return_last_state:
|
||||
state = rest[0]
|
||||
out_ref, *rest = selective_scan_ref(
|
||||
u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
|
||||
delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
|
||||
return_last_state=return_last_state
|
||||
)
|
||||
if return_last_state:
|
||||
state_ref = rest[0]
|
||||
# dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
||||
# dt_u = delta * u
|
||||
|
||||
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
|
||||
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
if return_last_state:
|
||||
print(f'State max diff: {(state - state_ref).abs().max().item()}')
|
||||
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
||||
|
||||
g = torch.randn_like(out)
|
||||
out_ref.backward(g)
|
||||
out.backward(g)
|
||||
|
||||
print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')
|
||||
print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')
|
||||
print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
|
||||
print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
|
||||
print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
|
||||
if has_D:
|
||||
print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
|
||||
if has_z:
|
||||
print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')
|
||||
if has_delta_bias:
|
||||
print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
|
||||
|
||||
assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
|
||||
assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
|
||||
assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
|
||||
assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
|
||||
atol=atolw if not is_variable_B else atol)
|
||||
assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
|
||||
atol=atolw if not is_variable_C else atol)
|
||||
if has_D:
|
||||
assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
|
||||
if has_z:
|
||||
assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)
|
||||
if has_delta_bias:
|
||||
assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
|
||||
# @pytest.mark.parametrize('wtype', [torch.complex64])
|
||||
# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize('itype', [torch.float32])
|
||||
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
|
||||
@pytest.mark.parametrize('seqlen', [128])
|
||||
@pytest.mark.parametrize("is_variable_C", [False, True])
|
||||
# @pytest.mark.parametrize("is_variable_C", [False])
|
||||
@pytest.mark.parametrize("is_variable_B", [False, True])
|
||||
# @pytest.mark.parametrize("is_variable_B", [True])
|
||||
def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
|
||||
device = 'cuda'
|
||||
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 3e-2, 5e-2
|
||||
rtolw, atolw = (1e-3, 1e-3)
|
||||
# If we have z, the errors on the weights seem higher
|
||||
rtolw = max(rtolw, rtol)
|
||||
atolw = max(atolw, atol)
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 2
|
||||
dim = 768
|
||||
dstate = 8
|
||||
dt_rank = 48
|
||||
is_complex = wtype == torch.complex64
|
||||
xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
|
||||
conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
|
||||
conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
||||
x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
|
||||
* (1 if not is_complex else 2),
|
||||
dim, device=device, dtype=itype, requires_grad=True)
|
||||
delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
|
||||
out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
|
||||
out_proj_bias = None
|
||||
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
|
||||
B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
|
||||
if not is_variable_B else None)
|
||||
C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
|
||||
if not is_variable_C else None)
|
||||
D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
||||
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
|
||||
B_proj_bias = None
|
||||
C_proj_bias = None
|
||||
xz_ref = xz.detach().clone().requires_grad_()
|
||||
conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
|
||||
conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
|
||||
x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
|
||||
delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
|
||||
out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
|
||||
out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
|
||||
if out_proj_bias is not None else None)
|
||||
A_ref = A.detach().clone().requires_grad_()
|
||||
B_ref = B.detach().clone().requires_grad_() if B is not None else None
|
||||
C_ref = C.detach().clone().requires_grad_() if C is not None else None
|
||||
D_ref = D.detach().clone().requires_grad_()
|
||||
delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
|
||||
out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
||||
out_proj_weight, out_proj_bias,
|
||||
A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
|
||||
out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
|
||||
delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
|
||||
A_ref, B_ref, C_ref, D_ref,
|
||||
delta_bias=delta_bias_ref, delta_softplus=True)
|
||||
# dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
||||
# dt_u = delta * u
|
||||
|
||||
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
|
||||
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
g = torch.randn_like(out)
|
||||
out_ref.backward(g)
|
||||
out.backward(g)
|
||||
|
||||
print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')
|
||||
print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
|
||||
if not is_variable_B:
|
||||
print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
|
||||
if not is_variable_C:
|
||||
print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
|
||||
print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
|
||||
print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
|
||||
print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')
|
||||
print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')
|
||||
print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')
|
||||
print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')
|
||||
print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')
|
||||
|
||||
# assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
|
||||
# assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
|
||||
# assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
|
||||
# assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
|
||||
# atol=atolw if not is_variable_B else atol)
|
||||
# assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
|
||||
# atol=atolw if not is_variable_C else atol)
|
||||
# assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
|
||||
# assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
|
||||
@ -0,0 +1,53 @@
|
||||
# Copyright (C) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytest
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
# @pytest.mark.parametrize('itype', [torch.float16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
# @pytest.mark.parametrize('has_z', [True])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
# @pytest.mark.parametrize("dstate", [16])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
# @pytest.mark.parametrize("dim", [2048])
|
||||
def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 2
|
||||
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||
B = torch.randn(batch_size, dstate, device=device)
|
||||
C = torch.randn(batch_size, dstate, device=device)
|
||||
D = torch.randn(dim, device=device)
|
||||
if has_z:
|
||||
z = torch.randn_like(x)
|
||||
else:
|
||||
z = None
|
||||
state_ref = state.detach().clone()
|
||||
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
|
||||
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
99
Mamba/mamba-main/train.py
Normal file
99
Mamba/mamba-main/train.py
Normal file
@ -0,0 +1,99 @@
|
||||
import os
|
||||
import pandas as pd
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, MambaConfig
|
||||
from trl import SFTTrainer
|
||||
from peft import LoraConfig
|
||||
from datasets import Dataset
|
||||
|
||||
# 设置环境变量来避免内存碎片化
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
|
||||
|
||||
# 数据文件夹路径
|
||||
data_folder = r'/mnt/Mamba/mamba-main/data/dataset'
|
||||
|
||||
# 检查路径是否存在
|
||||
if not os.path.exists(data_folder):
|
||||
raise ValueError(f"路径不存在: {data_folder}")
|
||||
|
||||
# 加载分词器和模型
|
||||
path = "/mnt/Mamba/mamba-130m-hf" # 模型路径
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(path, local_files_only=True, num_labels=8, use_mambapy=True)
|
||||
|
||||
print("加载成功")
|
||||
|
||||
# 配置训练参数
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=12, # 减少批处理大小
|
||||
logging_dir='./logs',
|
||||
logging_steps=10,
|
||||
learning_rate=2e-3,
|
||||
gradient_accumulation_steps=2, # 使用梯度累积减少显存占用
|
||||
fp16=True, # 启用混合精度训练
|
||||
)
|
||||
|
||||
# LoRA配置
|
||||
lora_config = LoraConfig(
|
||||
r=8, # 低秩分解的秩
|
||||
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
|
||||
task_type="SEQ_CLS", # 序列分类任务类型
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# 初始化Trainer
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
args=training_args,
|
||||
peft_config=lora_config,
|
||||
max_seq_length=512, # 设置max_seq_length参数
|
||||
)
|
||||
|
||||
# 分块加载和处理数据
|
||||
chunksize = 40000 # 设置合适的分块大小,每次读取数据的行数
|
||||
|
||||
|
||||
def preprocess_data(chunk):
|
||||
chunk = chunk.dropna() # 处理缺失值
|
||||
texts = chunk[["acc_x", "acc_y", "acc_z", "gyr_x", "gyr_y", "gyr_z", "mag_x", "mag_y", "mag_z"]].astype(str).apply(
|
||||
' '.join, axis=1).tolist()
|
||||
labels = chunk["Person_id"].astype(int).tolist() # 确保标签是整数类型
|
||||
encodings = tokenizer(texts, truncation=True, padding=True, max_length=1024)
|
||||
return {"input_ids": encodings["input_ids"], "attention_mask": encodings["attention_mask"], "labels": labels}
|
||||
|
||||
|
||||
# 读取训练数据并进行训练
|
||||
train_file_path = os.path.join(data_folder, 'train_data.csv')
|
||||
chunk_iter = pd.read_csv(train_file_path, chunksize=chunksize, header=0)
|
||||
|
||||
for chunk in chunk_iter:
|
||||
# 数据预处理
|
||||
processed_data = preprocess_data(chunk)
|
||||
dataset = Dataset.from_dict(processed_data)
|
||||
|
||||
# 训练模型
|
||||
trainer.train_dataset = dataset
|
||||
trainer.train()
|
||||
|
||||
# 清理CUDA缓存
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# 保存训练后的模型
|
||||
model.save_pretrained("./trained_model")
|
||||
tokenizer.save_pretrained("./trained_model")
|
||||
|
||||
print("模型保存成功")
|
||||
|
||||
# 读取测试数据并进行预测
|
||||
test_file_path = os.path.join(data_folder, 'test_data.csv')
|
||||
test_data = pd.read_csv(test_file_path, header=0)
|
||||
processed_test_data = preprocess_data(test_data)
|
||||
test_dataset = Dataset.from_dict(processed_test_data)
|
||||
|
||||
# 预测Person_id
|
||||
predictions = trainer.predict(test_dataset)
|
||||
|
||||
# 输出预测结果
|
||||
print(predictions)
|
||||
Loading…
x
Reference in New Issue
Block a user