gzhu-biyesheji/frontend/model_manage_page.py
carry 402bc73dce feat(model_manage_page): 增加模型保存和刷新功能
- 新增保存模型功能,用户可以输入模型名称并保存当前加载的模型
- 添加刷新模型列表按钮,用户可以随时更新模型下拉菜单中的选项
- 优化页面布局,使按钮和输入框更加合理地排列
2025-04-10 20:18:03 +08:00

108 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gradio as gr
import os # 导入os模块以便扫描文件夹
import sys
from pathlib import Path
from unsloth import FastLanguageModel
import torch
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import model,tokenizer
def model_manage_page():
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
models_dir = os.path.join(workdir, "models")
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))] # 扫描models文件夹下的所有子文件夹
with gr.Blocks() as demo:
gr.Markdown("## 模型管理")
output_text = gr.Textbox(label="当前状态", interactive=False)
with gr.Row():
with gr.Column(scale=3):
dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True) # 将子文件夹列表添加到Dropdown组件中并设置为可选
max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
with gr.Column(scale=1):
load_button = gr.Button("加载模型", variant="primary")
unload_button = gr.Button("卸载模型", variant="stop")
refresh_button = gr.Button("刷新模型列表", variant="secondary") # 新增刷新按钮
with gr.Row():
with gr.Column(scale=3):
save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
with gr.Column(scale=1):
save_button = gr.Button("保存模型", variant="secondary")
def load_model(selected_model, max_seq_length, load_in_4bit):
try:
global model, tokenizer
# 判空操作,如果模型已加载,则先卸载
if model is not None:
unload_model()
model_path = os.path.join(models_dir, selected_model)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_path,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
)
return f"模型 {selected_model} 已加载"
except Exception as e:
return f"加载模型时出错: {str(e)}"
load_button.click(fn=load_model, inputs=[dropdown, max_seq_length_input, load_in_4bit_input], outputs=output_text)
def unload_model():
try:
global model, tokenizer
# 将模型移动到 CPU
if model is not None:
model.cpu()
# 如果提供了 tokenizer也将其设置为 None
if tokenizer is not None:
tokenizer = None
# 清空 CUDA 缓存
torch.cuda.empty_cache()
# 将模型设置为 None
model = None
return "模型已卸载"
except Exception as e:
return f"卸载模型时出错: {str(e)}"
unload_button.click(fn=unload_model, inputs=None, outputs=output_text)
def save_model(save_model_name):
try:
global model, tokenizer
if model is None:
return "没有加载的模型可保存"
save_path = os.path.join(models_dir, save_model_name)
os.makedirs(save_path, exist_ok=True)
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
return f"模型已保存到 {save_path}"
except Exception as e:
return f"保存模型时出错: {str(e)}"
save_button.click(fn=save_model, inputs=save_model_name_input, outputs=output_text)
def refresh_model_list():
try:
nonlocal model_folders
model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]
return gr.Dropdown(choices=model_folders)
except Exception as e:
return f"刷新模型列表时出错: {str(e)}"
refresh_button.click(fn=refresh_model_list, inputs=None, outputs=dropdown)
return demo
if __name__ == "__main__":
demo = model_manage_page()
demo.queue()
demo.launch()