feat(train_page): 添加 TensorBoard 可视化
- 在训练页面添加 TensorBoard iframe 显示框 - 实现动态生成 TensorBoard iframe 功能 - 更新训练按钮点击事件,同时更新 TensorBoard iframe
This commit is contained in:
parent
664944f0c5
commit
aa758e3c2a
@ -25,6 +25,8 @@ def train_page():
|
||||
interactive=True
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
# 新增超参数输入组件
|
||||
learning_rate_input = gr.Number(value=2e-4, label="学习率")
|
||||
per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
|
||||
@ -36,6 +38,9 @@ def train_page():
|
||||
|
||||
# 训练状态输出
|
||||
output = gr.Textbox(label="训练日志", interactive=False)
|
||||
with gr.Column(scale=3):
|
||||
# 新增 TensorBoard iframe 显示框
|
||||
tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")
|
||||
|
||||
def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
|
||||
# 使用动态传入的超参数
|
||||
@ -68,6 +73,11 @@ def train_page():
|
||||
)
|
||||
print("TensorBoard 已启动,日志目录:", tensorboard_logdir)
|
||||
|
||||
# 动态生成 TensorBoard iframe
|
||||
tensorboard_url = f"http://localhost:{tensorboard_port}"
|
||||
tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="500px"></iframe>'
|
||||
yield "训练开始...", tensorboard_iframe_value # 返回两个值,分别对应 textbox 和 html
|
||||
|
||||
try:
|
||||
train_model(get_model(), get_tokenizer(),
|
||||
dataset, new_training_dir,
|
||||
@ -88,7 +98,7 @@ def train_page():
|
||||
save_steps_input,
|
||||
lora_rank_input
|
||||
],
|
||||
outputs=output
|
||||
outputs=[output, tensorboard_iframe] # 更新输出以包含 iframe
|
||||
)
|
||||
|
||||
return demo
|
||||
|
Loading…
x
Reference in New Issue
Block a user