
- Restructure the training directory so checkpoints and logs are saved separately
- Add TensorBoard logging support
- Remove the custom LossCallback class to simplify the training flow
- Update training arguments and the callback mechanism to improve code readability
- Add the tensorboardX dependency to requirements.txt
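The change moves checkpoints and TensorBoard event files into separate directories and logs through the built-in TensorBoard integration instead of a custom `LossCallback`. `tools.train_model` itself is not part of this file, so the snippet below is only a minimal sketch of how such a setup is commonly wired with `transformers.TrainingArguments`; the directory layout and the helper name `build_training_args` are assumptions, not taken from the repository.

```python
# Minimal sketch (assumed, not the repository's tools.py): separate checkpoint
# and TensorBoard log directories, with logging handled by the "tensorboard"
# integration (which is why tensorboardX ends up in requirements.txt)
# instead of a custom callback.
from transformers import TrainingArguments

def build_training_args(output_dir: str, learning_rate: float,
                        per_device_train_batch_size: int, epoch: int,
                        save_steps: int) -> TrainingArguments:
    return TrainingArguments(
        output_dir=f"{output_dir}/checkpoint",   # model checkpoints (assumed layout)
        logging_dir=f"{output_dir}/logs",        # TensorBoard event files (assumed layout)
        report_to=["tensorboard"],               # replaces the removed custom LossCallback
        logging_steps=1,
        learning_rate=learning_rate,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=epoch,
        save_strategy="steps",
        save_steps=save_steps,
    )
```

The LoRA rank collected by the page would go into a separate `peft.LoraConfig(r=...)` rather than into `TrainingArguments`, so it is omitted from this sketch.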
import sys
from pathlib import Path

import gradio as gr
from tinydb import Query

sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import train_model


def train_page():
    with gr.Blocks() as demo:
        gr.Markdown("## 微调")

        # Fetch the dataset list and pick an initial value
        datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
        initial_dataset = datasets_list[0] if datasets_list else None

        dataset_dropdown = gr.Dropdown(
            choices=datasets_list,
            value=initial_dataset,  # initially selected entry
            label="选择数据集",
            allow_custom_value=True,
            interactive=True
        )

        # Hyperparameter inputs
        learning_rate_input = gr.Number(value=2e-4, label="学习率")
        per_device_train_batch_size_input = gr.Number(value=1, label="batch size", precision=0)
        epoch_input = gr.Number(value=1, label="epoch", precision=0)
        save_steps_input = gr.Number(value=20, label="保存步数", precision=0)  # checkpoint save interval
        lora_rank_input = gr.Number(value=16, label="LoRA秩", precision=0)  # LoRA rank

        train_button = gr.Button("开始微调")

        # Training status output
        output = gr.Textbox(label="训练日志", interactive=False)

        def start_training(dataset_name, learning_rate, per_device_train_batch_size, epoch, save_steps, lora_rank):
            # Cast the hyperparameters passed in from the UI
            learning_rate = float(learning_rate)
            per_device_train_batch_size = int(per_device_train_batch_size)
            epoch = int(epoch)
            save_steps = int(save_steps)  # checkpoint save interval
            lora_rank = int(lora_rank)  # LoRA rank

            # Load the selected dataset
            dataset = get_datasets().get(Query().name == dataset_name)
            dataset = [ds["message"][0] for ds in dataset["dataset_items"]]

            train_model(get_model(), get_tokenizer(),
                        dataset, get_workdir() + "/1",
                        learning_rate, per_device_train_batch_size, epoch,
                        save_steps, lora_rank)
            # Report completion to the output textbox
            return "训练完成"

        train_button.click(
            fn=start_training,
            inputs=[
                dataset_dropdown,
                learning_rate_input,
                per_device_train_batch_size_input,
                epoch_input,
                save_steps_input,
                lora_rank_input
            ],
            outputs=output
        )

    return demo


if __name__ == "__main__":
    from global_var import init_global_var
    from model_manage_page import model_manage_page

    init_global_var("workdir")
    demo = gr.TabbedInterface([model_manage_page(), train_page()], ["模型管理", "微调"])
    demo.queue()
    demo.launch()