fix(global_var): 修复全局变量多文件多副本的不统一问题

This commit is contained in:
carry 2025-04-11 18:04:42 +08:00
parent 216bfe39ae
commit ab7897351a
6 changed files with 84 additions and 31 deletions

View File

@ -1,5 +1,6 @@
import gradio as gr
from global_var import docs, scan_docs_directory, prompt_store
from tools import scan_docs_directory
from global_var import get_docs, scan_docs_directory, get_prompt_store
def dataset_generate_page():
with gr.Blocks() as demo:
@ -7,7 +8,7 @@ def dataset_generate_page():
with gr.Row():
with gr.Column():
# 获取文档列表并设置初始值
docs_list = [str(doc.name) for doc in scan_docs_directory("workdir")]
docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None
doc_dropdown = gr.Dropdown(
@ -21,7 +22,7 @@ def dataset_generate_page():
with gr.Column():
# 获取模板列表并设置初始值
prompts = prompt_store.all()
prompts = get_prompt_store().all()
prompt_choices = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_choices[0] if prompt_choices else None

View File

@ -1,5 +1,5 @@
import gradio as gr
from global_var import datasets
from global_var import get_datasets
from tinydb import Query
def dataset_manage_page():
@ -7,7 +7,7 @@ def dataset_manage_page():
gr.Markdown("## 数据集管理")
with gr.Row():
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in datasets.all()]
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
@ -33,7 +33,7 @@ def dataset_manage_page():
# 从数据库获取数据集
Dataset = Query()
ds = datasets.get(Dataset.name == dataset_name)
ds = get_datasets().get(Dataset.name == dataset_name)
if not ds:
return {"samples": [], "__type__": "update"}

View File

@ -6,7 +6,7 @@ from unsloth import FastLanguageModel
import torch
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import model,tokenizer
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from tools.model import get_model_name
def model_manage_page():
@ -35,9 +35,8 @@ def model_manage_page():
def load_model(selected_model, max_seq_length, load_in_4bit):
try:
global model, tokenizer
# 判空操作,如果模型已加载,则先卸载
if model is not None:
if get_model() is not None:
unload_model()
model_path = os.path.join(models_dir, selected_model)
@ -46,6 +45,8 @@ def model_manage_page():
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
)
set_model(model)
set_tokenizer(tokenizer)
return f"模型 {get_model_name(model)} 已加载"
except Exception as e:
return f"加载模型时出错: {str(e)}"
@ -54,20 +55,17 @@ def model_manage_page():
def unload_model():
try:
global model, tokenizer
# 将模型移动到 CPU
model = get_model()
if model is not None:
model.cpu()
# 如果提供了 tokenizer也将其设置为 None
if tokenizer is not None:
tokenizer = None
# 清空 CUDA 缓存
torch.cuda.empty_cache()
# 将模型设置为 None
model = None
# 将模型和tokenizer设置为 None
set_model(None)
set_tokenizer(None)
return "当前未加载模型"
except Exception as e:

View File

@ -3,13 +3,13 @@ import sys
from pathlib import Path
from typing import List
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import prompt_store
from global_var import get_prompt_store
from schema.prompt import promptTempleta
def prompt_manage_page():
def get_prompts() -> List[List[str]]:
selected_row = None
try:
db = prompt_store
db = get_prompt_store()
prompts = db.all()
return [
[p["id"], p["name"], p["description"], p["content"]]
@ -20,7 +20,7 @@ def prompt_manage_page():
def add_prompt(name, description, content):
try:
db = prompt_store
db = get_prompt_store()
new_prompt = promptTempleta(
name=name if name else "",
description=description if description else "",
@ -38,7 +38,7 @@ def prompt_manage_page():
if not selected_row:
raise gr.Error("请先选择要编辑的行")
try:
db = prompt_store
db = get_prompt_store()
db.update({
"name": selected_row[1] if selected_row[1] else "",
"description": selected_row[2] if selected_row[2] else "",

View File

@ -2,13 +2,13 @@ import gradio as gr
from typing import List
from sqlmodel import Session, select
from schema import APIProvider
from global_var import sql_engine
from global_var import get_sql_engine
def setting_page():
def get_providers() -> List[List[str]]:
selected_row = None
try: # 添加异常处理
with Session(sql_engine) as session:
with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all()
return [
[p.id, p.model_id, p.base_url, p.api_key or ""]
@ -19,7 +19,7 @@ def setting_page():
def add_provider(model_id, base_url, api_key):
try:
with Session(sql_engine) as session:
with Session(get_sql_engine()) as session:
new_provider = APIProvider(
model_id=model_id if model_id else None,
base_url=base_url if base_url else None,
@ -37,7 +37,7 @@ def setting_page():
if not selected_row:
raise gr.Error("请先选择要编辑的行")
try:
with Session(sql_engine) as session:
with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0])
if not provider:
raise gr.Error("找不到选中的记录")
@ -56,7 +56,7 @@ def setting_page():
if not selected_row:
raise gr.Error("请先选择要删除的行")
try:
with Session(sql_engine) as session:
with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0])
if not provider:
raise gr.Error("找不到选中的记录")

View File

@ -1,10 +1,64 @@
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory
prompt_store = get_prompt_tinydb("workdir")
sql_engine = get_sqlite_engine("workdir")
docs = scan_docs_directory("workdir")
datasets = get_all_dataset("workdir")
_prompt_store = None
_sql_engine = None
_docs = None
_datasets = None
model = None
tokenizer = None
def get_prompt_store():
    """Return the process-wide prompt TinyDB, opening it in "workdir" on first use.

    Subsequent calls return the same cached instance, so every module that
    imports this accessor shares one store (the point of this refactor).
    """
    global _prompt_store
    if _prompt_store is not None:
        return _prompt_store
    _prompt_store = get_prompt_tinydb("workdir")
    return _prompt_store
def set_prompt_store(new_prompt_store):
    """Replace the shared prompt store (e.g. to point at a different workdir)."""
    global _prompt_store
    _prompt_store = new_prompt_store
def get_sql_engine():
    """Return the process-wide SQL engine, creating it for "workdir" on first call.

    The engine is cached in a module-level variable so all callers share one
    connection factory instead of each module building its own copy.
    """
    global _sql_engine
    if _sql_engine is not None:
        return _sql_engine
    _sql_engine = get_sqlite_engine("workdir")
    return _sql_engine
def set_sql_engine(new_sql_engine):
    """Replace the shared SQL engine (e.g. for tests or a different database)."""
    global _sql_engine
    _sql_engine = new_sql_engine
def get_docs():
    """Return the cached document list, scanning "workdir" on first access.

    The scan result is memoized module-wide; call set_docs() to refresh it
    after the directory contents change.
    """
    global _docs
    if _docs is not None:
        return _docs
    _docs = scan_docs_directory("workdir")
    return _docs
def set_docs(new_docs):
    """Replace the cached document list (e.g. after re-scanning the workdir)."""
    global _docs
    _docs = new_docs
def get_datasets():
    """Return the shared dataset store, loading it from "workdir" on first use.

    Loaded lazily and cached at module level so every page sees the same
    dataset collection.
    """
    global _datasets
    if _datasets is not None:
        return _datasets
    _datasets = get_all_dataset("workdir")
    return _datasets
def set_datasets(new_datasets):
    """Replace the shared dataset store (e.g. after reloading from disk)."""
    global _datasets
    _datasets = new_datasets
_model = None
_tokenizer = None
def get_model():
    """Return the currently registered model, or None if none has been set."""
    return _model
def set_model(new_model):
    """Register *new_model* as the process-wide model (None to clear it)."""
    global _model
    _model = new_model
def get_tokenizer():
    """Return the currently registered tokenizer, or None if none has been set."""
    return _tokenizer
def set_tokenizer(new_tokenizer):
    """Register *new_tokenizer* as the process-wide tokenizer (None to clear it)."""
    global _tokenizer
    _tokenizer = new_tokenizer