feat(model_manage_page): 增加模型保存和刷新功能

- 新增保存模型功能,用户可以输入模型名称并保存当前加载的模型
- 添加刷新模型列表按钮,用户可以随时更新模型下拉菜单中的选项
- 优化页面布局,使按钮和输入框更加合理地排列
This commit is contained in:
carry 2025-04-10 20:18:03 +08:00
parent bb5851f800
commit 402bc73dce

View File

@ -15,13 +15,22 @@ def model_manage_page():
with gr.Blocks() as demo:
gr.Markdown("## 模型管理")
dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # 将子文件夹列表添加到Dropdown组件中并设置为可选
max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
output_text = gr.Textbox(label="当前状态", interactive=False)
with gr.Row():
load_button = gr.Button("加载模型", variant="primary")
unload_button = gr.Button("卸载模型", variant="stop")
output_text = gr.Textbox(label="操作结果", interactive=False)
with gr.Column(scale=3):
dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # 将子文件夹列表添加到Dropdown组件中并设置为可选
max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
with gr.Column(scale=1):
load_button = gr.Button("加载模型", variant="primary")
unload_button = gr.Button("卸载模型", variant="stop")
refresh_button = gr.Button("刷新模型列表", variant="secondary") # 新增刷新按钮
with gr.Row():
with gr.Column(scale=3):
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
with gr.Column(scale=1):
save_button = gr.Button("保存模型", variant="secondary")
def load_model(selected_model, max_seq_length, load_in_4bit):
try:
@ -65,6 +74,32 @@ def model_manage_page():
unload_button.click(fn=unload_model, inputs=None, outputs=output_text)
def save_model(save_model_name):
try:
global model, tokenizer
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name)
os.makedirs(save_path, exist_ok=True)
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
return f"模型已保存到 {save_path}"
except Exception as e:
return f"保存模型时出错: {str(e)}"
save_button.click(fn=save_model, inputs=save_model_name_input, outputs=output_text)
def refresh_model_list():
try:
nonlocal model_folders
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]
return gr.Dropdown(choices=model_folders)
except Exception as e:
return f"刷新模型列表时出错: {str(e)}"
refresh_button.click(fn=refresh_model_list, inputs=None, outputs=dropdown)
return demo
if __name__ == "__main__":