feat(train_page): 添加 TensorBoard 可视化

- 在训练页面添加 TensorBoard iframe 显示框
- 实现动态生成 TensorBoard iframe 功能
- 更新训练按钮点击事件,同时更新 TensorBoard iframe
This commit is contained in:
carry 2025-04-14 23:28:43 +08:00
parent 664944f0c5
commit aa758e3c2a

View File

@ -25,18 +25,23 @@ def train_page():
interactive=True interactive=True
) )
# 新增超参数输入组件 with gr.Row():
learning_rate_input = gr.Number(value=2e-4, label="学习率") with gr.Column(scale=1):
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0) # 新增超参数输入组件
epoch_input = gr.Number(value=1, label="epoch", precision=0) learning_rate_input = gr.Number(value=2e-4, label="学习率")
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框 per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框 epoch_input = gr.Number(value=1, label="epoch", precision=0)
save_steps_input = gr.Number(value=20, label="保存步数", precision=0) # 新增保存步数输入框
train_button = gr.Button("开始微调") lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0) # 新增LoRA秩输入框
# 训练状态输出 train_button = gr.Button("开始微调")
output = gr.Textbox(label="训练日志", interactive=False)
# 训练状态输出
output = gr.Textbox(label="训练日志", interactive=False)
with gr.Column(scale=3):
# 新增 TensorBoard iframe 显示框
tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")
def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank): def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
# 使用动态传入的超参数 # 使用动态传入的超参数
learning_rate = float(learning_rate) learning_rate = float(learning_rate)
@ -68,6 +73,11 @@ def train_page():
) )
print("TensorBoard 已启动,日志目录:", tensorboard_logdir) print("TensorBoard 已启动,日志目录:", tensorboard_logdir)
# 动态生成 TensorBoard iframe
tensorboard_url = f"http://localhost:{tensorboard_port}"
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="500px"></iframe>'
yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html
try: try:
train_model(get_model(), get_tokenizer(), train_model(get_model(), get_tokenizer(),
dataset, new_training_dir, dataset, new_training_dir,
@ -88,7 +98,7 @@ def train_page():
save_steps_input, save_steps_input,
lora_rank_input lora_rank_input
], ],
outputs=output outputs=[output, tensorboard_iframe] # 更新输出以包含 iframe
) )
return demo return demo