diff --git a/frontend/train_page.py b/frontend/train_page.py
index 9ad226f..a8f261c 100644
--- a/frontend/train_page.py
+++ b/frontend/train_page.py
@@ -1,8 +1,8 @@
 import gradio as gr
 import sys
-import torch
 from tinydb import Query
 from pathlib import Path
+from transformers import TrainerCallback
 
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from global_var import get_model, get_tokenizer, get_datasets, get_workdir
@@ -45,7 +45,31 @@ def train_page():
         # 加载数据集
         dataset = get_datasets().get(Query().name == dataset_name)
         dataset = [ds["message"][0] for ds in dataset["dataset_items"]]
-        train_model(get_model(), get_tokenizer(), dataset, get_workdir(), learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank)
+
+        class LossCallback(TrainerCallback):
+            def __init__(self):
+                self.loss_data = []
+                self.log_text = ""
+                self.last_output = {"text": "", "plot": None}
+            def on_log(self, args, state, control, logs=None, **kwargs):
+                # Guard: logs defaults to None, and eval/summary logs carry no "loss".
+                if logs and "loss" in logs:
+                    self.loss_data.append({
+                        "step": state.global_step,
+                        "loss": float(logs["loss"])
+                    })
+                    self.log_text += f"Step {state.global_step}: loss={logs['loss']:.4f}\n"
+                    print(f"Current Step: {state.global_step}")
+                    print(f"Loss Value: {logs['loss']:.4f}")
+                    self.last_output = {
+                        "text": self.log_text, "plot": None,
+                    }
+                # No control returned, so the training flow is left untouched.
+
+        train_model(get_model(), get_tokenizer(),
+                    dataset, get_workdir()+"/checkpoint",
+                    learning_rate, per_device_train_batch_size, epoch,
+                    save_steps, lora_rank, LossCallback)
 
 
     train_button.click(
diff --git a/tools/model.py b/tools/model.py
index b06c0bc..062b222 100644
--- a/tools/model.py
+++ b/tools/model.py
@@ -40,7 +40,8 @@ def train_model(
     per_device_train_batch_size: int,
     epoch: int,
     save_steps: int,
-    lora_rank: int
+    lora_rank: int,
+    trainer_callback
 ) -> None:
     # 模型配置参数
     dtype = None  # 数据类型,None表示自动选择
@@ -115,10 +116,12 @@ def train_model(
             output_dir=output_dir,  # 保存模型检查点和训练日志
             save_strategy="steps",  # 按步保存中间权重
             save_steps=save_steps,  # 使用动态传入的保存步数
-            # report_to="tensorboard",  # 将信息输出到tensorboard
+            report_to="none",
         ),
     )
 
+    trainer.add_callback(trainer_callback)
+
     trainer = train_on_responses_only(
         trainer,
         instruction_part = "<|im_start|>user\n",