From aa758e3c2a48ad0c525b53e9be1c18bedb774e17 Mon Sep 17 00:00:00 2001 From: carry Date: Mon, 14 Apr 2025 23:28:43 +0800 Subject: [PATCH] =?UTF-8?q?feat(train=5Fpage):=20=E6=B7=BB=E5=8A=A0=20Tens?= =?UTF-8?q?orBoard=20=E5=8F=AF=E8=A7=86=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在训练页面添加 TensorBoard iframe 显示框 - 实现动态生成 TensorBoard iframe 功能 - 更新训练按钮点击事件,同时更新 TensorBoard iframe --- frontend/train_page.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/frontend/train_page.py b/frontend/train_page.py index 188f901..2acbbd5 100644 --- a/frontend/train_page.py +++ b/frontend/train_page.py @@ -25,18 +25,23 @@ def train_page(): interactive=True ) - # 新增超参数输入组件 - learning_rate_input = gr.Number(value=2e-4, label="学习率") - per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0) - epoch_input = gr.Number(value=1, label="epoch", precision=0) - save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框 - lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框 - - train_button = gr.Button("开始微调") - - # 训练状态输出 - output = gr.Textbox(label="训练日志", interactive=False) - + with gr.Row(): + with gr.Column(scale=1): + # 新增超参数输入组件 + learning_rate_input = gr.Number(value=2e-4, label="学习率") + per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0) + epoch_input = gr.Number(value=1, label="epoch", precision=0) + save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框 + lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框 + + train_button = gr.Button("开始微调") + + # 训练状态输出 + output = gr.Textbox(label="训练日志", interactive=False) + with gr.Column(scale=3): + # 新增 TensorBoard iframe 显示框 + tensorboard_iframe = gr.HTML(label="TensorBoard 可视化") + def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank): # 使用动态传入的超参数 learning_rate = float(learning_rate) @@ -68,6 +73,11 @@ def train_page(): ) print("TensorBoard 已启动,日志目录:", tensorboard_logdir) + # 动态生成 TensorBoard iframe + tensorboard_url = f"http://localhost:{tensorboard_port}" + tensorboard_iframe_value = f'' + yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html + try: train_model(get_model(), get_tokenizer(), dataset, new_training_dir, @@ -88,7 +98,7 @@ def train_page(): save_steps_input, lora_rank_input ], - outputs=output + outputs=[output, tensorboard_iframe] # 更新输出以包含 iframe ) return demo