113 lines
5.1 KiB
Python
113 lines
5.1 KiB
Python
import subprocess
|
|
import os
|
|
import gradio as gr
|
|
import sys
|
|
from tinydb import Query
|
|
from pathlib import Path
|
|
from transformers import TrainerCallback
|
|
|
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
|
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
|
|
from tools import find_available_port
|
|
from train import train_model
|
|
|
|
def train_page():
    """Build the fine-tuning page of the Gradio UI.

    The page offers a dataset dropdown, numeric inputs for the training
    hyperparameters, a start button, a status textbox and an HTML pane in
    which a TensorBoard instance is embedded while training runs.

    Returns:
        The ``gr.Blocks`` object holding the page.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 微调")
        # Collect the dataset names and preselect the first one, if any.
        datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
        initial_dataset = datasets_list[0] if datasets_list else None
        with gr.Row():
            with gr.Column(scale=1):
                dataset_dropdown = gr.Dropdown(
                    choices=datasets_list,
                    value=initial_dataset,  # initially selected entry
                    label="选择数据集",
                    allow_custom_value=True,
                    interactive=True
                )
                # Hyperparameter inputs.
                learning_rate_input = gr.Number(value=2e-4, label="学习率")
                per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
                epoch_input = gr.Number(value=1, label="epoch", precision=0)
                save_steps_input = gr.Number(value=20, label="保存步数", precision=0)
                lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0)

                train_button = gr.Button("开始微调")

                # Training status output.
                output = gr.Textbox(label="训练状态", interactive=False)
            with gr.Column(scale=3):
                # Pane that receives the TensorBoard iframe.
                tensorboard_iframe = gr.HTML(label="TensorBoard 可视化")

        def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
            """Run one fine-tuning job and stream UI updates.

            Generator event handler: yields ``(status_text, iframe_html)``
            once before training starts, then blocks in ``train_model``.

            Raises:
                gr.Error: if the selected dataset does not exist or if
                    training fails.
            """
            # Gradio number inputs may arrive as floats; coerce explicitly.
            learning_rate = float(learning_rate)
            per_device_train_batch_size = int(per_device_train_batch_size)
            epoch = int(epoch)
            save_steps = int(save_steps)
            lora_rank = int(lora_rank)

            # Load the dataset. The dropdown allows custom values, so the
            # lookup can legitimately come back empty.
            dataset_record = get_datasets().get(Query().name == dataset_name)
            if dataset_record is None:
                raise gr.Error(f"未找到数据集: {dataset_name}")
            dataset = [ds["message"][0] for ds in dataset_record["dataset_items"]]

            # Scan the training folder and pick the next numbered run directory.
            training_dir = os.path.join(get_workdir(), "training")
            os.makedirs(training_dir, exist_ok=True)
            existing_dirs = [d for d in os.listdir(training_dir) if d.isdigit()]
            next_dir_number = max((int(d) for d in existing_dirs), default=0) + 1
            new_training_dir = os.path.join(training_dir, str(next_dir_number))

            # Probe for a free port, starting at TensorBoard's default 6006.
            tensorboard_port = find_available_port(6006)
            print(f"TensorBoard 将使用端口: {tensorboard_port}")

            tensorboard_logdir = os.path.join(new_training_dir, "logs")
            os.makedirs(tensorboard_logdir, exist_ok=True)
            # The child's output is never read, so discard it: an unread
            # PIPE would block TensorBoard once the OS pipe buffer fills.
            tensorboard_process = subprocess.Popen(
                ["tensorboard", "--logdir", tensorboard_logdir, "--port", str(tensorboard_port)],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            print("TensorBoard 已启动,日志目录:", tensorboard_logdir)

            # Everything after the launch sits inside try/finally so the
            # child is reaped even if the generator is closed at the yield.
            try:
                tensorboard_url = f"http://localhost:{tensorboard_port}"
                tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
                # Two values: one for the textbox, one for the HTML pane.
                yield "训练开始...", tensorboard_iframe_value

                train_model(get_model(), get_tokenizer(),
                            dataset, new_training_dir,
                            learning_rate, per_device_train_batch_size, epoch,
                            save_steps, lora_rank)
            except Exception as e:
                # Surface the failure in the UI while keeping the cause.
                raise gr.Error(str(e)) from e
            finally:
                # Stop TensorBoard and wait so no zombie process is left.
                tensorboard_process.terminate()
                try:
                    tensorboard_process.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    tensorboard_process.kill()
                    tensorboard_process.wait()
                print("TensorBoard 子进程已终止")

        train_button.click(
            fn=start_training,
            inputs=[
                dataset_dropdown,
                learning_rate_input,
                per_device_train_batch_size_input,
                epoch_input,
                save_steps_input,
                lora_rank_input
            ],
            outputs=[output, tensorboard_iframe]  # status textbox + iframe pane
        )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Stand-alone entry point: serve the model-management page and this
    # fine-tuning page as two tabs.
    from global_var import init_global_var
    from model_manage_page import model_manage_page

    init_global_var("workdir")
    # The second tab hosts train_page(), so it carries the same title as the
    # page heading ("微调") rather than the previous, mismatched "聊天".
    demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "微调"])
    # Queue events before launching the app.
    demo.queue()
    demo.launch()