更新 Mamba/mamba-main/train.py
parent
9eea6c07af
commit
f65b091fac
|
@ -4,7 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
|
||||||
from trl import SFTTrainer
|
from trl import SFTTrainer
|
||||||
from peft import LoraConfig
|
from peft import LoraConfig
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
import torch
|
||||||
# 设置环境变量来避免内存碎片化
|
# 设置环境变量来避免内存碎片化
|
||||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue