
The model management page gains a save-mode selector: a dropdown lets the user pick among save modes such as default, merged 16-bit, and merged 4-bit. The model-saving logic is also pulled out into a standalone `save_model.py`, making it easier to maintain and reuse.
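The extracted `train/save_model.py` itself is not part of this listing. Below is a minimal sketch of what `save_model_to_dir` could look like, assuming unsloth's `save_pretrained_merged` / `save_pretrained_gguf` helpers and the argument order used by the page code that follows (name, directory, model, tokenizer, method); the real file may differ:

# train/save_model.py -- a minimal sketch, not the actual extracted file.
import os

def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method):
    """Save the currently loaded model under models_dir using the chosen save mode."""
    if model is None or tokenizer is None:
        return "No model loaded; nothing to save"
    if not save_model_name:
        return "Please enter a name for the saved model"

    save_path = os.path.join(models_dir, save_model_name)
    try:
        if save_method == "default":
            # Plain save_pretrained: for a PEFT model this stores only the adapter weights
            model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
        elif save_method in ("merged_16bit", "merged_4bit", "lora"):
            # unsloth helper that merges LoRA weights into the base model
            # (or exports the adapter alone for "lora"); note that recent unsloth
            # versions may ask for "merged_4bit_forced" instead of "merged_4bit"
            model.save_pretrained_merged(save_path, tokenizer, save_method=save_method)
        elif save_method == "gguf":
            # GGUF export with unsloth's default quantization (q8_0)
            model.save_pretrained_gguf(save_path, tokenizer)
        elif save_method == "gguf_q4_k_m":
            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
        elif save_method == "gguf_f16":
            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
        else:
            return f"Unknown save mode: {save_method}"
        return f"Model saved to {save_path} ({save_method})"
    except Exception as e:
        return f"Error while saving model: {str(e)}"

The page itself follows.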
import gradio as gr
import os  # used to scan the models directory
import sys
from pathlib import Path
from unsloth import FastLanguageModel
import torch

sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from train import get_model_name

def model_manage_page():
    workdir = "workdir"  # assume workdir is a folder under the current working directory
    models_dir = os.path.join(workdir, "models")
    # Each subdirectory of models/ is treated as a selectable model
    model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]

    with gr.Blocks() as demo:
        gr.Markdown("## Model Management")
        state_output = gr.Label(label="Current status", value="No model loaded")  # a Label instead of a Textbox
        with gr.Row():
            with gr.Column(scale=3):
                # Populate the dropdown with the scanned model subfolders
                model_select_dropdown = gr.Dropdown(choices=model_folders, label="Select model", interactive=True)
                max_seq_length_input = gr.Number(label="Max sequence length", value=4096, precision=0)
                load_in_4bit_input = gr.Checkbox(label="Load in 4-bit", value=True)
            with gr.Column(scale=1):
                load_button = gr.Button("Load model", variant="primary")
                unload_button = gr.Button("Unload model", variant="stop")
                refresh_button = gr.Button("Refresh model list", variant="secondary")  # new refresh button

        with gr.Row():
            with gr.Column(scale=3):
                save_model_name_input = gr.Textbox(label="Saved model name", placeholder="Enter a name for the saved model")
                save_method_dropdown = gr.Dropdown(
                    choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
                    label="Save mode",
                    value="default"
                )
            with gr.Column(scale=1):
                save_button = gr.Button("Save model", variant="secondary")

        def load_model(selected_model, max_seq_length, load_in_4bit):
            try:
                # If a model is already loaded, unload it first
                if get_model() is not None:
                    unload_model()

                model_path = os.path.join(models_dir, selected_model)
                model, tokenizer = FastLanguageModel.from_pretrained(
                    model_name=model_path,
                    max_seq_length=max_seq_length,
                    load_in_4bit=load_in_4bit,
                )
                set_model(model)
                set_tokenizer(tokenizer)
                return f"Model {get_model_name(model)} loaded"
            except Exception as e:
                return f"Error while loading model: {str(e)}"

        load_button.click(fn=load_model, inputs=[model_select_dropdown, max_seq_length_input, load_in_4bit_input], outputs=state_output)

        def unload_model():
            try:
                # Move the model to the CPU first so its GPU memory can be released
                model = get_model()
                if model is not None:
                    model.cpu()

                # Release cached CUDA memory
                torch.cuda.empty_cache()

                # Drop the global references to the model and tokenizer
                set_model(None)
                set_tokenizer(None)

                return "No model loaded"
            except Exception as e:
                return f"Error while unloading model: {str(e)}"

        unload_button.click(fn=unload_model, inputs=None, outputs=state_output)

        from train.save_model import save_model_to_dir

        def save_model(save_model_name, save_method):
            # Delegate the actual saving to the extracted train/save_model.py helper
            return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)

        save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)

        def refresh_model_list():
            nonlocal model_folders
            try:
                model_folders = [name for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))]
            except Exception as e:
                # The output component is a Dropdown, so returning an error string here
                # would become the dropdown value; log instead and keep the old choices
                print(f"Error while refreshing model list: {e}", file=sys.stderr)
            return gr.Dropdown(choices=model_folders)

        refresh_button.click(fn=refresh_model_list, inputs=None, outputs=model_select_dropdown)

    return demo

if __name__ == "__main__":
    demo = model_manage_page()
    demo.queue()
    demo.launch()
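`global_var` is also external to this file. A minimal sketch consistent with the imports above (the real module may hold more state) could be:

# global_var.py -- a minimal sketch consistent with the imports above.
# Holds the currently loaded model/tokenizer as module-level state so that
# other pages in the app can share the same instances.
_model = None
_tokenizer = None

def get_model():
    return _model

def set_model(model):
    global _model
    _model = model

def get_tokenizer():
    return _tokenizer

def set_tokenizer(tokenizer):
    global _tokenizer
    _tokenizer = tokenizer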