From ab7897351a599397c3ee716247cf856e5ad4111a Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Fri, 11 Apr 2025 18:04:42 +0800 Subject: [PATCH] =?UTF-8?q?fix(global=5Fvar):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=85=A8=E5=B1=80=E5=8F=98=E9=87=8F=E5=A4=9A=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=A4=9A=E5=89=AF=E6=9C=AC=E7=9A=84=E4=B8=8D=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/dataset_generate_page.py | 7 ++-- frontend/dataset_manage_page.py | 6 +-- frontend/model_manage_page.py | 18 ++++----- frontend/prompt_manage_page.py | 8 ++-- frontend/setting_page.py | 10 ++--- global_var.py | 66 ++++++++++++++++++++++++++++--- 6 files changed, 84 insertions(+), 31 deletions(-) diff --git a/frontend/dataset_generate_page.py b/frontend/dataset_generate_page.py index 2cfcacd..ad344fe 100644 --- a/frontend/dataset_generate_page.py +++ b/frontend/dataset_generate_page.py @@ -1,5 +1,6 @@ import gradio as gr -from global_var import docs, scan_docs_directory, prompt_store +from tools import scan_docs_directory +from global_var import get_docs, scan_docs_directory, get_prompt_store def dataset_generate_page(): with gr.Blocks() as demo: @@ -7,7 +8,7 @@ def dataset_generate_page(): with gr.Row(): with gr.Column(): # 获取文档列表并设置初始值 - docs_list = [str(doc.name) for doc in scan_docs_directory("workdir")] + docs_list = [str(doc.name) for doc in get_docs()] initial_doc = docs_list[0] if docs_list else None doc_dropdown = gr.Dropdown( @@ -21,7 +22,7 @@ def dataset_generate_page(): with gr.Column(): # 获取模板列表并设置初始值 - prompts = prompt_store.all() + prompts = get_prompt_store().all() prompt_choices = [f"{p['id']} {p['name']}" for p in prompts] initial_prompt = prompt_choices[0] if prompt_choices else None diff --git a/frontend/dataset_manage_page.py b/frontend/dataset_manage_page.py index 0ef45f6..052734d 100644 --- a/frontend/dataset_manage_page.py +++ b/frontend/dataset_manage_page.py @@ -1,5 +1,5 @@ import gradio as gr -from global_var import datasets +from global_var import get_datasets from tinydb import Query def dataset_manage_page(): @@ -7,7 +7,7 @@ def dataset_manage_page(): gr.Markdown("## 数据集管理") with gr.Row(): # 获取数据集列表并设置初始值 - datasets_list = [str(ds["name"]) for ds in datasets.all()] + datasets_list = [str(ds["name"]) for ds in get_datasets().all()] initial_dataset = datasets_list[0] if datasets_list else None dataset_dropdown = gr.Dropdown( @@ -33,7 +33,7 @@ def dataset_manage_page(): # 从数据库获取数据集 Dataset = Query() - ds = datasets.get(Dataset.name == dataset_name) + ds = get_datasets().get(Dataset.name == dataset_name) if not ds: return {"samples": [], "__type__": "update"} diff --git a/frontend/model_manage_page.py b/frontend/model_manage_page.py index aa3b302..fe2f0ce 100644 --- a/frontend/model_manage_page.py +++ b/frontend/model_manage_page.py @@ -6,7 +6,7 @@ from unsloth import FastLanguageModel import torch sys.path.append(str(Path(__file__).resolve().parent.parent)) -from global_var import model,tokenizer +from global_var import get_model, get_tokenizer, set_model, set_tokenizer from tools.model import get_model_name def model_manage_page(): @@ -35,9 +35,8 @@ def model_manage_page(): def load_model(selected_model, max_seq_length, load_in_4bit): try: - global model, tokenizer # 判空操作,如果模型已加载,则先卸载 - if model is not None: + if get_model() is not None: unload_model() model_path = os.path.join(models_dir, selected_model) @@ -46,6 +45,8 @@ def model_manage_page(): max_seq_length=max_seq_length, load_in_4bit=load_in_4bit, ) + set_model(model) + set_tokenizer(tokenizer) return f"模型 {get_model_name(model)} 已加载" except Exception as e: return f"加载模型时出错: {str(e)}" @@ -54,20 +55,17 @@ def model_manage_page(): def unload_model(): try: - global model, tokenizer # 将模型移动到 CPU + model = get_model() if model is not None: model.cpu() - # 如果提供了 tokenizer,也将其设置为 None - if tokenizer is not None: - tokenizer = None - # 清空 CUDA 缓存 torch.cuda.empty_cache() - # 将模型设置为 None - model = None + # 将模型和tokenizer设置为 None + set_model(None) + set_tokenizer(None) return "当前未加载模型" except Exception as e: diff --git a/frontend/prompt_manage_page.py b/frontend/prompt_manage_page.py index a2426c7..b856ead 100644 --- a/frontend/prompt_manage_page.py +++ b/frontend/prompt_manage_page.py @@ -3,13 +3,13 @@ import sys from pathlib import Path from typing import List sys.path.append(str(Path(__file__).resolve().parent.parent)) -from global_var import prompt_store +from global_var import get_prompt_store from schema.prompt import promptTempleta def prompt_manage_page(): def get_prompts() -> List[List[str]]: selected_row = None try: - db = prompt_store + db = get_prompt_store() prompts = db.all() return [ [p["id"], p["name"], p["description"], p["content"]] @@ -20,7 +20,7 @@ def prompt_manage_page(): def add_prompt(name, description, content): try: - db = prompt_store + db = get_prompt_store() new_prompt = promptTempleta( name=name if name else "", description=description if description else "", @@ -38,7 +38,7 @@ def prompt_manage_page(): if not selected_row: raise gr.Error("请先选择要编辑的行") try: - db = prompt_store + db = get_prompt_store() db.update({ "name": selected_row[1] if selected_row[1] else "", "description": selected_row[2] if selected_row[2] else "", diff --git a/frontend/setting_page.py b/frontend/setting_page.py index 1e827de..3dba298 100644 --- a/frontend/setting_page.py +++ b/frontend/setting_page.py @@ -2,13 +2,13 @@ import gradio as gr from typing import List from sqlmodel import Session, select from schema import APIProvider -from global_var import sql_engine +from global_var import get_sql_engine def setting_page(): def get_providers() -> List[List[str]]: selected_row = None try: # 添加异常处理 - with Session(sql_engine) as session: + with Session(get_sql_engine()) as session: providers = session.exec(select(APIProvider)).all() return [ [p.id, p.model_id, p.base_url, p.api_key or ""] @@ -19,7 +19,7 @@ def setting_page(): def add_provider(model_id, base_url, api_key): try: - with Session(sql_engine) as session: + with Session(get_sql_engine()) as session: new_provider = APIProvider( model_id=model_id if model_id else None, base_url=base_url if base_url else None, @@ -37,7 +37,7 @@ def setting_page(): if not selected_row: raise gr.Error("请先选择要编辑的行") try: - with Session(sql_engine) as session: + with Session(get_sql_engine()) as session: provider = session.get(APIProvider, selected_row[0]) if not provider: raise gr.Error("找不到选中的记录") @@ -56,7 +56,7 @@ def setting_page(): if not selected_row: raise gr.Error("请先选择要删除的行") try: - with Session(sql_engine) as session: + with Session(get_sql_engine()) as session: provider = session.get(APIProvider, selected_row[0]) if not provider: raise gr.Error("找不到选中的记录") diff --git a/global_var.py b/global_var.py index 328cb13..53566d5 100644 --- a/global_var.py +++ b/global_var.py @@ -1,10 +1,64 @@ from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset from tools import scan_docs_directory -prompt_store = get_prompt_tinydb("workdir") -sql_engine = get_sqlite_engine("workdir") -docs = scan_docs_directory("workdir") -datasets = get_all_dataset("workdir") +_prompt_store = None +_sql_engine = None +_docs = None +_datasets = None -model = None -tokenizer = None \ No newline at end of file +def get_prompt_store(): + global _prompt_store + if _prompt_store is None: + _prompt_store = get_prompt_tinydb("workdir") + return _prompt_store + +def set_prompt_store(new_prompt_store): + global _prompt_store + _prompt_store = new_prompt_store + +def get_sql_engine(): + global _sql_engine + if _sql_engine is None: + _sql_engine = get_sqlite_engine("workdir") + return _sql_engine + +def set_sql_engine(new_sql_engine): + global _sql_engine + _sql_engine = new_sql_engine + +def get_docs(): + global _docs + if _docs is None: + _docs = scan_docs_directory("workdir") + return _docs + +def set_docs(new_docs): + global _docs + _docs = new_docs + +def get_datasets(): + global _datasets + if _datasets is None: + _datasets = get_all_dataset("workdir") + return _datasets + +def set_datasets(new_datasets): + global _datasets + _datasets = new_datasets + +_model = None +_tokenizer = None + +def get_model(): + return _model + +def set_model(new_model): + global _model + _model = new_model + +def get_tokenizer(): + return _tokenizer + +def set_tokenizer(new_tokenizer): + global _tokenizer + _tokenizer = new_tokenizer \ No newline at end of file