From 7a4388c928d9345843559235c38be2711e9ee0d7 Mon Sep 17 00:00:00 2001
From: carry
Date: Wed, 23 Apr 2025 14:09:02 +0800
Subject: [PATCH] feat(model): add save mode selection
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add a save mode selector to the model management page: users can now
pick a save mode (e.g. default, merged 16-bit, merged 4-bit) from a
dropdown menu. The model-saving logic is also extracted into a
standalone `save_model.py` module to improve maintainability and
reusability.
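Illustrative usage of the new helper; "my-model" and "models" below are
placeholder values, not part of this change, and get_model/get_tokenizer
come from global_var as in the frontend:

    from train.save_model import save_model_to_dir
    from global_var import get_model, get_tokenizer

    # save_method mirrors the dropdown choices in frontend/model_manage_page.py.
    msg = save_model_to_dir("my-model", "models", get_model(),
                            get_tokenizer(), save_method="merged_16bit")
    # msg is a status string such as "模型已保存到 models/my-model (模式: merged_16bit)";
    # the helper reports failures as strings instead of raising, so the UI can display them.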
---
 frontend/model_manage_page.py | 24 ++++++++----------
 train/save_model.py           | 47 +++++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 14 deletions(-)
 create mode 100644 train/save_model.py

diff --git a/frontend/model_manage_page.py b/frontend/model_manage_page.py
index 7483049..12f6227 100644
--- a/frontend/model_manage_page.py
+++ b/frontend/model_manage_page.py
@@ -30,6 +30,11 @@ def model_manage_page():
         with gr.Row():
             with gr.Column(scale=3):
                 save_model_name_input = gr.Textbox(label="保存模型名称", placeholder="输入模型保存名称")
+                save_method_dropdown = gr.Dropdown(
+                    choices=["default", "merged_16bit", "merged_4bit", "lora", "gguf", "gguf_q4_k_m", "gguf_f16"],
+                    label="保存模式",
+                    value="default"
+                )
             with gr.Column(scale=1):
                 save_button = gr.Button("保存模型", variant="secondary")
 
@@ -73,21 +78,12 @@ def model_manage_page():
 
         unload_button.click(fn=unload_model, inputs=None, outputs=state_output)
 
-        def save_model(save_model_name):
-            try:
-                global model, tokenizer
-                if model is None:
-                    return "没有加载的模型可保存"
-
-                save_path = os.path.join(models_dir, save_model_name)
-                os.makedirs(save_path, exist_ok=True)
-                model.save_pretrained(save_path)
-                tokenizer.save_pretrained(save_path)
-                return f"模型已保存到 {save_path}"
-            except Exception as e:
-                return f"保存模型时出错: {str(e)}"
+        from train.save_model import save_model_to_dir
+
+        def save_model(save_model_name, save_method):
+            return save_model_to_dir(save_model_name, models_dir, get_model(), get_tokenizer(), save_method)
 
-        save_button.click(fn=save_model, inputs=save_model_name_input, outputs=state_output)
+        save_button.click(fn=save_model, inputs=[save_model_name_input, save_method_dropdown], outputs=state_output)
 
         def refresh_model_list():
             try:
diff --git a/train/save_model.py b/train/save_model.py
new file mode 100644
index 0000000..e99ca69
--- /dev/null
+++ b/train/save_model.py
@@ -0,0 +1,47 @@
+import os
+
+def save_model_to_dir(save_model_name, models_dir, model, tokenizer, save_method="default"):
+    """
+    Save a model to the given directory.
+    :param save_model_name: name to save the model under
+    :param models_dir: base directory for saved models
+    :param model: the model to save
+    :param tokenizer: the tokenizer to save
+    :param save_method: save mode option
+        - "default": plain save_pretrained
+        - "merged_16bit": merge weights to 16-bit
+        - "merged_4bit": merge weights to 4-bit
+        - "lora": LoRA adapters only
+        - "gguf": GGUF format (default quantization)
+        - "gguf_q4_k_m": GGUF with q4_k_m quantization
+        - "gguf_f16": GGUF in 16-bit
+    :return: success message, or an error message on failure
+    """
+    try:
+        if model is None:
+            return "没有加载的模型可保存"
+
+        save_path = os.path.join(models_dir, save_model_name)
+        os.makedirs(save_path, exist_ok=True)
+
+        if save_method == "default":
+            model.save_pretrained(save_path)
+            tokenizer.save_pretrained(save_path)
+        elif save_method == "merged_16bit":
+            model.save_pretrained_merged(save_path, tokenizer, save_method="merged_16bit")
+        elif save_method == "merged_4bit":
+            model.save_pretrained_merged(save_path, tokenizer, save_method="merged_4bit_forced")
+        elif save_method == "lora":
+            model.save_pretrained_merged(save_path, tokenizer, save_method="lora")
+        elif save_method == "gguf":
+            model.save_pretrained_gguf(save_path, tokenizer)
+        elif save_method == "gguf_q4_k_m":
+            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="q4_k_m")
+        elif save_method == "gguf_f16":
+            model.save_pretrained_gguf(save_path, tokenizer, quantization_method="f16")
+        else:
+            return f"不支持的保存模式: {save_method}"
+
+        return f"模型已保存到 {save_path} (模式: {save_method})"
+    except Exception as e:
+        return f"保存模型时出错: {str(e)}"
\ No newline at end of file