Update Mamba/mamba-main/train.py
parent 9eea6c07af
commit f65b091fac
@@ -1,99 +1,99 @@
import os

import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

# Cap allocator block splitting to reduce CUDA memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
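# Recent PyTorch releases accept further allocator options in this variable,
# e.g. "expandable_segments:True", if fragmentation is still an issue.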

# Path to the dataset folder
data_folder = r'/mnt/Mamba/mamba-main/data/dataset'

# Check that the dataset path exists
if not os.path.exists(data_folder):
    raise ValueError(f"Path does not exist: {data_folder}")

# Load the tokenizer and model
path = "/mnt/Mamba/mamba-130m-hf"  # local model path
tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
# The pretrained Mamba tokenizer may ship without a pad token, which the
# padding=True calls below require; fall back to the EOS token in that case
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# num_labels is stored on the model config; use_mambapy falls back to the
# pure-PyTorch mamba.py implementation when the fused CUDA kernels are absent
model = AutoModelForCausalLM.from_pretrained(path, local_files_only=True, num_labels=8, use_mambapy=True)

print("Model loaded successfully")

# Training configuration
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=12,  # reduced batch size to fit GPU memory
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3,
    gradient_accumulation_steps=2,  # accumulate gradients to lower peak memory use
    fp16=True,  # enable mixed-precision training
)
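
# With these settings the effective batch size per optimizer step is
# per_device_train_batch_size * gradient_accumulation_steps = 12 * 2 = 24.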

# LoRA configuration
lora_config = LoraConfig(
    r=8,  # rank of the low-rank decomposition
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="SEQ_CLS",  # sequence-classification task type
    bias="none"
)
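
# Optional sanity check (a sketch, not part of the training path): wrapping a
# fresh copy of the model with the adapter and printing the trainable-parameter
# count shows how little of the 130M-parameter base LoRA actually updates.
# SFTTrainer applies peft_config itself, so do not run this on the `model`
# object handed to the trainer below.
#
#     from peft import get_peft_model
#     probe = get_peft_model(
#         AutoModelForCausalLM.from_pretrained(path, local_files_only=True),
#         lora_config)
#     probe.print_trainable_parameters()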

# Initialize the trainer (train_dataset is attached chunk by chunk below)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    max_seq_length=512,  # cap on sequence length used by the trainer
)
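
# The datasets attached below already contain input_ids/attention_mask, so
# (in the trl versions this script appears to target) SFTTrainer skips its own
# tokenization and uses them as-is; max_seq_length only applies to raw text.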

# Load and process the data in chunks
chunksize = 40000  # number of CSV rows to read per chunk


def preprocess_data(chunk):
    chunk = chunk.dropna()  # drop rows with missing values
    # Join the nine sensor channels into one whitespace-separated string per row
    texts = chunk[["acc_x", "acc_y", "acc_z", "gyr_x", "gyr_y", "gyr_z", "mag_x", "mag_y", "mag_z"]].astype(str).apply(
        ' '.join, axis=1).tolist()
    labels = chunk["Person_id"].astype(int).tolist()  # ensure labels are integers
    # Truncate/pad to at most 512 tokens to match the trainer's max_seq_length
    encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
    return {"input_ids": encodings["input_ids"], "attention_mask": encodings["attention_mask"], "labels": labels}
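
# Each chunk becomes a dict of three parallel lists; for example,
# preprocess_data(chunk)["input_ids"][0] holds the token ids of the first
# row's "acc_x acc_y ... mag_z" string and ["labels"][0] its Person_id.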


# Read the training data and train chunk by chunk
train_file_path = os.path.join(data_folder, 'train_data.csv')
chunk_iter = pd.read_csv(train_file_path, chunksize=chunksize, header=0)

for chunk in chunk_iter:
    # Preprocess the chunk
    processed_data = preprocess_data(chunk)
    dataset = Dataset.from_dict(processed_data)

    # Train the model; each trainer.train() call runs the full
    # num_train_epochs on the current chunk's dataset
    trainer.train_dataset = dataset
    trainer.train()

    # Release cached CUDA memory between chunks
    torch.cuda.empty_cache()
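
# If the full dataset fits in (Arrow-backed) memory, an alternative to the
# per-chunk loop above is to build one dataset and call train() once, so the
# learning-rate schedule spans all the data. A sketch under that assumption:
#
#     from datasets import concatenate_datasets
#     parts = [Dataset.from_dict(preprocess_data(c))
#              for c in pd.read_csv(train_file_path, chunksize=chunksize, header=0)]
#     trainer.train_dataset = concatenate_datasets(parts)
#     trainer.train()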

# Save the trained model
model.save_pretrained("./trained_model")
tokenizer.save_pretrained("./trained_model")

print("Model saved successfully")

# Read the test data and run prediction
test_file_path = os.path.join(data_folder, 'test_data.csv')
test_data = pd.read_csv(test_file_path, header=0)
processed_test_data = preprocess_data(test_data)
test_dataset = Dataset.from_dict(processed_test_data)

# Predict Person_id
predictions = trainer.predict(test_dataset)

# Print the raw prediction output
print(predictions)
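
# predictions.predictions holds raw logits; turning them into Person_id class
# indices would take an argmax over the label axis. A sketch, assuming the
# model emits one 8-way logit vector per example:
#
#     import numpy as np
#     predicted_ids = np.argmax(predictions.predictions, axis=-1)
#     pd.DataFrame({"Person_id": predicted_ids}).to_csv("predicted_ids.csv", index=False)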