From e08f0059bbb49b75d0027b47f06b341071c7127a Mon Sep 17 00:00:00 2001 From: carry Date: Sun, 13 Apr 2025 21:05:14 +0800 Subject: [PATCH] =?UTF-8?q?feat(train=5Fpage):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E8=BF=87=E7=A8=8B=E4=BB=A5=E4=B8=93=E6=B3=A8?= =?UTF-8?q?=E4=BA=8E=E5=93=8D=E5=BA=94=E7=94=9F=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 引入 train_on_responses_only 函数,用于优化训练过程 - 设置 instruction_part 和 response_part 参数,以适应特定的对话格式 - 此修改旨在提高模型在生成响应方面的性能和效率 --- frontend/train_page.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/frontend/train_page.py b/frontend/train_page.py index 3e7fcbb..d312dc8 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -9,7 +9,7 @@ from unsloth import FastLanguageModel from trl import SFTTrainer # 用于监督微调的训练器 from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数 from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练 -from unsloth.chat_templates import get_chat_template +from unsloth.chat_templates import get_chat_template, train_on_responses_only from tools import formatting_prompts_func sys.path.append(str(Path(__file__).resolve().parent.parent)) @@ -132,7 +132,13 @@ def train_page(): # report_to="tensorboard", # 将信息输出到tensorboard ), ) - + + trainer = train_on_responses_only( + trainer, + instruction_part = "<|im_start|>user\n", + response_part = "<|im_start|>assistant\n", + ) + # 开始训练 trainer_stats = trainer.train(resume_from_checkpoint=False)