From 402bc73dce867d9a683b034902cc3d5541dd135e Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Thu, 10 Apr 2025 20:18:03 +0800 Subject: [PATCH] =?UTF-8?q?feat(model=5Fmanage=5Fpage):=20=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=A8=A1=E5=9E=8B=E4=BF=9D=E5=AD=98=E5=92=8C=E5=88=B7?= =?UTF-8?q?=E6=96=B0=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增保存模型功能,用户可以输入模型名称并保存当前加载的模型 - 添加刷新模型列表按钮,用户可以随时更新模型下拉菜单中的选项 - 优化页面布局,使按钮和输入框更加合理地排列 --- frontend/model_manage_page.py | 47 ++++++++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/frontend/model_manage_page.py b/frontend/model_manage_page.py index f35764e..88390fb 100644 --- a/frontend/model_manage_page.py +++ b/frontend/model_manage_page.py @@ -15,13 +15,22 @@ def model_manage_page(): with gr.Blocks() as demo: gr.Markdown("## 模型管理") - dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # 将子文件夹列表添加到Dropdown组件中,并设置为可选 - max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0) - load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True) + output_text = gr.Textbox(label="当前状态", interactive=False) with gr.Row(): - load_button = gr.Button("加载模型", variant="primary") - unload_button = gr.Button("卸载模型", variant="stop") - output_text = gr.Textbox(label="操作结果", interactive=False) + with gr.Column(scale=3): + dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # 将子文件夹列表添加到Dropdown组件中,并设置为可选 + max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0) + load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True) + with gr.Column(scale=1): + load_button = gr.Button("加载模型", variant="primary") + unload_button = gr.Button("卸载模型", variant="stop") + refresh_button = gr.Button("刷新模型列表", variant="secondary") # 新增刷新按钮 + + with gr.Row(): + with gr.Column(scale=3): + save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称") + with gr.Column(scale=1): + save_button = gr.Button("保存模型", variant="secondary") def load_model(selected_model, max_seq_length, load_in_4bit): try: @@ -65,6 +74,32 @@ def model_manage_page(): unload_button.click(fn=unload_model, inputs=None, outputs=output_text) + def save_model(save_model_name): + try: + global model, tokenizer + if model is None: + return "没有加载的模型可保存" + + save_path = os.path.join(models_dir, save_model_name) + os.makedirs(save_path, exist_ok=True) + model.save_pretrained(save_path) + tokenizer.save_pretrained(save_path) + return f"模型已保存到 {save_path}" + except Exception as e: + return f"保存模型时出错: {str(e)}" + + save_button.click(fn=save_model, inputs=save_model_name_input, outputs=output_text) + + def refresh_model_list(): + try: + nonlocal model_folders + model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))] + return gr.Dropdown(choices=model_folders) + except Exception as e: + return f"刷新模型列表时出错: {str(e)}" + + refresh_button.click(fn=refresh_model_list, inputs=None, outputs=dropdown) + return demo if __name__ == "__main__":