From 088067d335f4fad93d993f3d42a2f43de2a754f4 Mon Sep 17 00:00:00 2001 From: carry Date: Mon, 14 Apr 2025 16:19:37 +0800 Subject: [PATCH] =?UTF-8?q?train:=20=E6=9B=B4=E6=96=B0=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E5=8A=9F=E8=83=BD=E5=92=8C=E6=97=A5=E5=BF=97?= =?UTF-8?q?=E8=AE=B0=E5=BD=95=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修改训练目录结构,将检查点和日志分开保存 - 添加 TensorBoard 日志记录支持 - 移除自定义 LossCallback 类,简化训练流程 - 更新训练参数和回调机制,提高代码可读性 - 在 requirements.txt 中添加 tensorboardX 依赖 --- frontend/train_page.py | 26 +++----------------------- requirements.txt | 3 ++- tools/model.py | 14 ++++++++------ 3 files changed, 13 insertions(+), 30 deletions(-) diff --git a/frontend/train_page.py b/frontend/train_page.py index a8f261c..e5db3c1 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -45,31 +45,11 @@ def train_page(): # 加载数据集 dataset = get_datasets().get(Query().name == dataset_name) dataset = [ds["message"][0] for ds in dataset["dataset_items"]] - - class LossCallback(TrainerCallback): - def __init__(self): - self.loss_data = [] - self.log_text = "" - self.last_output = {"text": "", "plot": None} - def on_log(self, args, state, control, logs=None, **kwargs): - if "loss" in logs: - self.loss_data.append({ - "step": state.global_step, - "loss": float(logs["loss"]) - }) - self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n" - # 添加以下两行print语句 - print(f"Current Step: {state.global_step}") - print(f"Loss Value: {logs['loss']:.4f}") - self.last_output = { - "text": self.log_text, - } - # 不返回 control,避免干预训练过程 train_model(get_model(), get_tokenizer(), - dataset, get_workdir()+"/checkpoint", + dataset, get_workdir() + "/1", learning_rate, per_device_train_batch_size, epoch, - save_steps, lora_rank, LossCallback) + save_steps, lora_rank) train_button.click( @@ -80,7 +60,7 @@ def train_page(): per_device_train_batch_size_input, epoch_input, save_steps_input, - lora_rank_input # 新增lora_rank_input + lora_rank_input ], outputs=output ) diff --git a/requirements.txt b/requirements.txt index ad485ba..5671485 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ langchain>=0.3 tinydb>=4.0.0 unsloth>=2025.3.19 sqlmodel>=0.0.24 -jinja2>=3.1.0 \ No newline at end of file +jinja2>=3.1.0 +tensorboardX>=2.6.2.2 \ No newline at end of file diff --git a/tools/model.py b/tools/model.py index 062b222..2b70c48 100644 --- a/tools/model.py +++ b/tools/model.py @@ -35,13 +35,13 @@ def train_model( model, tokenizer, dataset: list, - output_dir: str, + train_dir: str, learning_rate: float, per_device_train_batch_size: int, epoch: int, save_steps: int, lora_rank: int, - trainer_callback + trainer_callback=None ) -> None: # 模型配置参数 dtype = None # 数据类型,None表示自动选择 @@ -113,14 +113,16 @@ def train_model( optim="adamw_8bit", # 使用8位AdamW优化器节省显存,几乎不影响训练效果 weight_decay=0.01, # 权重衰减系数,用于正则化,防止过拟合 seed=114514, # 随机数种子 - output_dir=output_dir, # 保存模型检查点和训练日志 + output_dir=train_dir + "/checkpoints", # 保存模型检查点和训练日志 save_strategy="steps", # 按步保存中间权重 save_steps=save_steps, # 使用动态传入的保存步数 - report_to="none", + logging_dir=train_dir + "/logs", # 日志文件存储路径 + report_to="tensorboard", # 使用TensorBoard记录日志 ), ) - - trainer.add_callback(trainer_callback) + + if trainer_callback is not None: + trainer.add_callback(trainer_callback) trainer = train_on_responses_only( trainer,