feat(model_manage_page): 实现模型加载和卸载功能
- 添加模型加载和卸载按钮 - 实现模型加载和卸载的逻辑 - 添加相关模块的导入 - 扫描模型目录并显示在下拉框中
This commit is contained in:
parent
4b465ec917
commit
a407fa1f76
@ -1,9 +1,73 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import os # 导入os模块以便扫描文件夹
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from unsloth import FastLanguageModel
|
||||||
|
import torch
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
|
from global_var import model,tokenizer
|
||||||
|
|
||||||
def model_manage_page():
    """Build the Gradio "model management" page.

    Scans ``workdir/models`` for model sub-folders, lists them in a dropdown,
    and wires "load" / "unload" buttons that update the shared model/tokenizer
    state held in the ``global_var`` module.

    Returns:
        gr.Blocks: the assembled Gradio page (caller is expected to launch it).
    """
    workdir = "workdir"  # assumed to be a folder under the current working directory
    models_dir = os.path.join(workdir, "models")

    # Guard against a missing models directory so the page still renders
    # (the original os.listdir call raised FileNotFoundError at build time).
    if os.path.isdir(models_dir):
        model_folders = [
            name
            for name in os.listdir(models_dir)
            if os.path.isdir(os.path.join(models_dir, name))
        ]
    else:
        model_folders = []

    with gr.Blocks() as demo:
        gr.Markdown("## 模型管理")
        dropdown = gr.Dropdown(choices=model_folders, label="选择模型", interactive=True)
        max_seq_length_input = gr.Number(label="最大序列长度", value=4096, precision=0)
        load_in_4bit_input = gr.Checkbox(label="使用4位量化", value=True)
        with gr.Row():
            with gr.Column():
                load_button = gr.Button("加载模型", variant="primary")
                unload_button = gr.Button("卸载模型", variant="stop")
        output_text = gr.Textbox(label="操作结果", interactive=False)

        def unload_model():
            """Release the shared model/tokenizer and free CUDA memory.

            Returns a status string for the output textbox.
            """
            try:
                # NOTE(review): write through the module object. The original
                # used ``global model, tokenizer`` after a from-import, which
                # only rebinds THIS module's copies — other modules reading
                # ``global_var.model`` would never see the change.
                import global_var
                if global_var.model is not None:
                    # Move weights off the GPU before dropping the reference.
                    global_var.model.cpu()
                global_var.model = None
                global_var.tokenizer = None
                # Return the freed GPU memory to the driver.
                torch.cuda.empty_cache()
                return "模型已卸载"
            except Exception as e:
                return f"卸载模型时出错: {str(e)}"

        def load_model(selected_model, max_seq_length, load_in_4bit):
            """Load the selected model folder into the shared global state.

            Args:
                selected_model: folder name chosen in the dropdown.
                max_seq_length: maximum sequence length for the model.
                load_in_4bit: whether to load with 4-bit quantization.

            Returns a status string for the output textbox.
            """
            try:
                import global_var
                # If a model is already resident, unload it first so we do
                # not hold two models in GPU memory at once.
                if global_var.model is not None:
                    unload_model()

                model_path = os.path.join(models_dir, selected_model)
                global_var.model, global_var.tokenizer = FastLanguageModel.from_pretrained(
                    model_name=model_path,
                    max_seq_length=max_seq_length,
                    load_in_4bit=load_in_4bit,
                )
                return f"模型 {selected_model} 已加载"
            except Exception as e:
                return f"加载模型时出错: {str(e)}"

        load_button.click(
            fn=load_model,
            inputs=[dropdown, max_seq_length_input, load_in_4bit_input],
            outputs=output_text,
        )
        unload_button.click(fn=unload_model, inputs=None, outputs=output_text)

    return demo
|
if __name__ == "__main__":
    # Manual smoke test: serve only this page.
    page = model_manage_page()
    page.queue()
    page.launch()
|
Loading…
x
Reference in New Issue
Block a user