Compare commits

...

2 Commits

Author SHA1 Message Date
carry
0722748997 feat(train_page): 添加 LoRA 秩动态输入功能
- 在训练页面新增 LoRA 秩输入框,使用户可以动态设置 LoRA 秩
- 更新训练模型函数,添加 LoRA 秩参数并将其用于模型配置
- 保留原有功能,仅增加 LoRA 秩相关配置
2025-04-13 21:12:02 +08:00
carry
e08f0059bb feat(train_page): 优化训练过程以专注于响应生成
- 引入 train_on_responses_only 函数,用于优化训练过程
- 设置 instruction_part 和 response_part 参数,以适应特定的对话格式
- 此修改旨在提高模型在生成响应方面的性能和效率
2025-04-13 21:05:14 +08:00

View File

@ -9,7 +9,7 @@ from unsloth import FastLanguageModel
from trl import SFTTrainer # 用于监督微调的训练器
from transformers import TrainingArguments,DataCollatorForSeq2Seq # 用于配置训练参数
from unsloth import is_bfloat16_supported # 检查是否支持bfloat16精度训练
from unsloth.chat_templates import get_chat_template
from unsloth.chat_templates import get_chat_template, train_on_responses_only
from tools import formatting_prompts_func
sys.path.append(str(Path(__file__).resolve().parent.parent))
@ -36,18 +36,20 @@ def train_page():
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
epoch_input = gr.Number(value=1, label="epoch", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框
train_button = gr.Button("开始微调")
# 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False)
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps):
def train_model(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
# 使用动态传入的超参数
learning_rate = float(learning_rate)
per_device_train_batch_size = int(per_device_train_batch_size)
epoch = int(epoch)
save_steps = int(save_steps) # 新增保存步数参数
lora_rank = int(lora_rank) # 新增LoRA秩参数
# 模型配置参数
dtype = None # 数据类型None表示自动选择
@ -62,7 +64,7 @@ def train_page():
model,
# LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大
# 建议: 8-32之间
r=16,
r=lora_rank, # 使用动态传入的LoRA秩
# 需要应用LoRA的目标模块列表
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj", # attention相关层
@ -132,7 +134,13 @@ def train_page():
# report_to="tensorboard", # 将信息输出到tensorboard
),
)
trainer = train_on_responses_only(
trainer,
instruction_part = "<|im_start|>user\n",
response_part = "<|im_start|>assistant\n",
)
# 开始训练
trainer_stats = trainer.train(resume_from_checkpoint=False)
@ -155,6 +163,5 @@ if __name__ == "__main__":
from model_manage_page import model_manage_page
init_global_var("workdir")
demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "聊天"])
# demo = gr.TabbedInterface([ train_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()