更新 Mamba/mamba-main/train.py

main
rzzn 2024-07-30 17:53:00 +08:00
parent 9eea6c07af
commit f65b091fac
1 changed files with 99 additions and 99 deletions

View File

@ -4,7 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
from trl import SFTTrainer from trl import SFTTrainer
from peft import LoraConfig from peft import LoraConfig
from datasets import Dataset from datasets import Dataset
import torch
# 设置环境变量来避免内存碎片化 # 设置环境变量来避免内存碎片化
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"