From f65593674190787bf8084c6b2a408a175fa0bd95 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Fri, 11 Apr 2025 18:08:16 +0800 Subject: [PATCH] =?UTF-8?q?refactor(global=5Fvar):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E5=85=A8=E5=B1=80=E5=8F=98=E9=87=8F=E5=88=9D=E5=A7=8B=E5=8C=96?= =?UTF-8?q?=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 init_global_var 函数,用于统一初始化所有全局变量 - 修改 get_prompt_store、get_sql_engine、get_docs 和 get_datasets 函数,使用新的全局变量初始化逻辑 - 更新 main.py 中的代码,使用新的 init_global_var 函数替代原有的单独初始化方法 --- global_var.py | 20 ++++++++------------ main.py | 9 +++++---- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/global_var.py b/global_var.py index 53566d5..69d353a 100644 --- a/global_var.py +++ b/global_var.py @@ -6,10 +6,15 @@ _sql_engine = None _docs = None _datasets = None +def init_global_var(workdir="workdir"): + """Initialize all global variables""" + global _prompt_store, _sql_engine, _docs, _datasets + _prompt_store = get_prompt_tinydb(workdir) + _sql_engine = get_sqlite_engine(workdir) + _docs = scan_docs_directory(workdir) + _datasets = get_all_dataset(workdir) + def get_prompt_store(): - global _prompt_store - if _prompt_store is None: - _prompt_store = get_prompt_tinydb("workdir") return _prompt_store def set_prompt_store(new_prompt_store): @@ -17,9 +22,6 @@ def set_prompt_store(new_prompt_store): _prompt_store = new_prompt_store def get_sql_engine(): - global _sql_engine - if _sql_engine is None: - _sql_engine = get_sqlite_engine("workdir") return _sql_engine def set_sql_engine(new_sql_engine): @@ -27,9 +29,6 @@ def set_sql_engine(new_sql_engine): _sql_engine = new_sql_engine def get_docs(): - global _docs - if _docs is None: - _docs = scan_docs_directory("workdir") return _docs def set_docs(new_docs): @@ -37,9 +36,6 @@ def set_docs(new_docs): _docs = new_docs def get_datasets(): - global _datasets - if _datasets is None: - _datasets = get_all_dataset("workdir") return _datasets def set_datasets(new_datasets): diff --git a/main.py b/main.py index 3797c1e..1158201 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,13 @@ import gradio as gr from frontend.setting_page import setting_page from frontend import * -from db import initialize_sqlite_db,initialize_prompt_store -from global_var import sql_engine,prompt_store +from db import initialize_sqlite_db, initialize_prompt_store +from global_var import init_global_var, get_sql_engine, get_prompt_store if __name__ == "__main__": - initialize_sqlite_db(sql_engine) - initialize_prompt_store(prompt_store) + init_global_var() + initialize_sqlite_db(get_sql_engine()) + initialize_prompt_store(get_prompt_store()) with gr.Blocks() as app: gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架") with gr.Tabs():