fix(global_var): 修复全局变量多文件多副本的不统一问题

This commit is contained in:
carry 2025-04-11 18:04:42 +08:00
parent 216bfe39ae
commit ab7897351a
6 changed files with 84 additions and 31 deletions

View File

@ -1,5 +1,6 @@
import gradio as gr import gradio as gr
from global_var import docs, scan_docs_directory, prompt_store from tools import scan_docs_directory
from global_var import get_docs, scan_docs_directory, get_prompt_store
def dataset_generate_page(): def dataset_generate_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
@ -7,7 +8,7 @@ def dataset_generate_page():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
# 获取文档列表并设置初始值 # 获取文档列表并设置初始值
docs_list = [str(doc.name) for doc in scan_docs_directory("workdir")] docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None initial_doc = docs_list[0] if docs_list else None
doc_dropdown = gr.Dropdown( doc_dropdown = gr.Dropdown(
@ -21,7 +22,7 @@ def dataset_generate_page():
with gr.Column(): with gr.Column():
# 获取模板列表并设置初始值 # 获取模板列表并设置初始值
prompts = prompt_store.all() prompts = get_prompt_store().all()
prompt_choices = [f"{p['id']} {p['name']}" for p in prompts] prompt_choices = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_choices[0] if prompt_choices else None initial_prompt = prompt_choices[0] if prompt_choices else None

View File

@ -1,5 +1,5 @@
import gradio as gr import gradio as gr
from global_var import datasets from global_var import get_datasets
from tinydb import Query from tinydb import Query
def dataset_manage_page(): def dataset_manage_page():
@ -7,7 +7,7 @@ def dataset_manage_page():
gr.Markdown("## 数据集管理") gr.Markdown("## 数据集管理")
with gr.Row(): with gr.Row():
# 获取数据集列表并设置初始值 # 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in datasets.all()] datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown( dataset_dropdown = gr.Dropdown(
@ -33,7 +33,7 @@ def dataset_manage_page():
# 从数据库获取数据集 # 从数据库获取数据集
Dataset = Query() Dataset = Query()
ds = datasets.get(Dataset.name == dataset_name) ds = get_datasets().get(Dataset.name == dataset_name)
if not ds: if not ds:
return {"samples": [], "__type__": "update"} return {"samples": [], "__type__": "update"}

View File

@ -6,7 +6,7 @@ from unsloth import FastLanguageModel
import torch import torch
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import model,tokenizer from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from tools.model import get_model_name from tools.model import get_model_name
def model_manage_page(): def model_manage_page():
@ -35,9 +35,8 @@ def model_manage_page():
def load_model(selected_model, max_seq_length, load_in_4bit): def load_model(selected_model, max_seq_length, load_in_4bit):
try: try:
global model, tokenizer
# 判空操作,如果模型已加载,则先卸载 # 判空操作,如果模型已加载,则先卸载
if model is not None: if get_model() is not None:
unload_model() unload_model()
model_path = os.path.join(models_dir, selected_model) model_path = os.path.join(models_dir, selected_model)
@ -46,6 +45,8 @@ def model_manage_page():
max_seq_length=max_seq_length, max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit, load_in_4bit=load_in_4bit,
) )
set_model(model)
set_tokenizer(tokenizer)
return f"模型 {get_model_name(model)} 已加载" return f"模型 {get_model_name(model)} 已加载"
except Exception as e: except Exception as e:
return f"加载模型时出错: {str(e)}" return f"加载模型时出错: {str(e)}"
@ -54,20 +55,17 @@ def model_manage_page():
def unload_model(): def unload_model():
try: try:
global model, tokenizer
# 将模型移动到 CPU # 将模型移动到 CPU
model = get_model()
if model is not None: if model is not None:
model.cpu() model.cpu()
# 如果提供了 tokenizer也将其设置为 None
if tokenizer is not None:
tokenizer = None
# 清空 CUDA 缓存 # 清空 CUDA 缓存
torch.cuda.empty_cache() torch.cuda.empty_cache()
# 将模型设置为 None # 将模型和tokenizer设置为 None
model = None set_model(None)
set_tokenizer(None)
return "当前未加载模型" return "当前未加载模型"
except Exception as e: except Exception as e:

View File

@ -3,13 +3,13 @@ import sys
from pathlib import Path from pathlib import Path
from typing import List from typing import List
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import prompt_store from global_var import get_prompt_store
from schema.prompt import promptTempleta from schema.prompt import promptTempleta
def prompt_manage_page(): def prompt_manage_page():
def get_prompts() -> List[List[str]]: def get_prompts() -> List[List[str]]:
selected_row = None selected_row = None
try: try:
db = prompt_store db = get_prompt_store()
prompts = db.all() prompts = db.all()
return [ return [
[p["id"], p["name"], p["description"], p["content"]] [p["id"], p["name"], p["description"], p["content"]]
@ -20,7 +20,7 @@ def prompt_manage_page():
def add_prompt(name, description, content): def add_prompt(name, description, content):
try: try:
db = prompt_store db = get_prompt_store()
new_prompt = promptTempleta( new_prompt = promptTempleta(
name=name if name else "", name=name if name else "",
description=description if description else "", description=description if description else "",
@ -38,7 +38,7 @@ def prompt_manage_page():
if not selected_row: if not selected_row:
raise gr.Error("请先选择要编辑的行") raise gr.Error("请先选择要编辑的行")
try: try:
db = prompt_store db = get_prompt_store()
db.update({ db.update({
"name": selected_row[1] if selected_row[1] else "", "name": selected_row[1] if selected_row[1] else "",
"description": selected_row[2] if selected_row[2] else "", "description": selected_row[2] if selected_row[2] else "",

View File

@ -2,13 +2,13 @@ import gradio as gr
from typing import List from typing import List
from sqlmodel import Session, select from sqlmodel import Session, select
from schema import APIProvider from schema import APIProvider
from global_var import sql_engine from global_var import get_sql_engine
def setting_page(): def setting_page():
def get_providers() -> List[List[str]]: def get_providers() -> List[List[str]]:
selected_row = None selected_row = None
try: # 添加异常处理 try: # 添加异常处理
with Session(sql_engine) as session: with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all() providers = session.exec(select(APIProvider)).all()
return [ return [
[p.id, p.model_id, p.base_url, p.api_key or ""] [p.id, p.model_id, p.base_url, p.api_key or ""]
@ -19,7 +19,7 @@ def setting_page():
def add_provider(model_id, base_url, api_key): def add_provider(model_id, base_url, api_key):
try: try:
with Session(sql_engine) as session: with Session(get_sql_engine()) as session:
new_provider = APIProvider( new_provider = APIProvider(
model_id=model_id if model_id else None, model_id=model_id if model_id else None,
base_url=base_url if base_url else None, base_url=base_url if base_url else None,
@ -37,7 +37,7 @@ def setting_page():
if not selected_row: if not selected_row:
raise gr.Error("请先选择要编辑的行") raise gr.Error("请先选择要编辑的行")
try: try:
with Session(sql_engine) as session: with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0]) provider = session.get(APIProvider, selected_row[0])
if not provider: if not provider:
raise gr.Error("找不到选中的记录") raise gr.Error("找不到选中的记录")
@ -56,7 +56,7 @@ def setting_page():
if not selected_row: if not selected_row:
raise gr.Error("请先选择要删除的行") raise gr.Error("请先选择要删除的行")
try: try:
with Session(sql_engine) as session: with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0]) provider = session.get(APIProvider, selected_row[0])
if not provider: if not provider:
raise gr.Error("找不到选中的记录") raise gr.Error("找不到选中的记录")

View File

@ -1,10 +1,64 @@
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory from tools import scan_docs_directory
prompt_store = get_prompt_tinydb("workdir") _prompt_store = None
sql_engine = get_sqlite_engine("workdir") _sql_engine = None
docs = scan_docs_directory("workdir") _docs = None
datasets = get_all_dataset("workdir") _datasets = None
model = None def get_prompt_store():
tokenizer = None global _prompt_store
if _prompt_store is None:
_prompt_store = get_prompt_tinydb("workdir")
return _prompt_store
def set_prompt_store(new_prompt_store):
global _prompt_store
_prompt_store = new_prompt_store
def get_sql_engine():
global _sql_engine
if _sql_engine is None:
_sql_engine = get_sqlite_engine("workdir")
return _sql_engine
def set_sql_engine(new_sql_engine):
global _sql_engine
_sql_engine = new_sql_engine
def get_docs():
global _docs
if _docs is None:
_docs = scan_docs_directory("workdir")
return _docs
def set_docs(new_docs):
global _docs
_docs = new_docs
def get_datasets():
global _datasets
if _datasets is None:
_datasets = get_all_dataset("workdir")
return _datasets
def set_datasets(new_datasets):
global _datasets
_datasets = new_datasets
_model = None
_tokenizer = None
def get_model():
return _model
def set_model(new_model):
global _model
_model = new_model
def get_tokenizer():
return _tokenizer
def set_tokenizer(new_tokenizer):
global _tokenizer
_tokenizer = new_tokenizer