diff --git a/frontend/train_page.py b/frontend/train_page.py index 3e7fcbb..d312dc8 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -9,7 +9,7 @@ from unsloth import FastLanguageModel from trl import SFTTrainer # 用于监督微调的训练器 from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数 from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练 -from unsloth.chat_templates import get_chat_template +from unsloth.chat_templates import get_chat_template, train_on_responses_only from tools import formatting_prompts_func sys.path.append(str(Path(__file__).resolve().parent.parent)) @@ -132,7 +132,13 @@ def train_page(): # report_to="tensorboard", # 将信息输出到tensorboard ), ) - + + trainer = train_on_responses_only( + trainer, + instruction_part = "<|im_start|>user\n", + response_part = "<|im_start|>assistant\n", + ) + # 开始训练 trainer_stats = trainer.train(resume_from_checkpoint=False)