diff --git a/global_var.py b/global_var.py index 53566d5..69d353a 100644 --- a/global_var.py +++ b/global_var.py @@ -6,10 +6,15 @@ _sql_engine = None _docs = None _datasets = None +def init_global_var(workdir="workdir"): + """Initialize all global variables""" + global _prompt_store, _sql_engine, _docs, _datasets + _prompt_store = get_prompt_tinydb(workdir) + _sql_engine = get_sqlite_engine(workdir) + _docs = scan_docs_directory(workdir) + _datasets = get_all_dataset(workdir) + def get_prompt_store(): - global _prompt_store - if _prompt_store is None: - _prompt_store = get_prompt_tinydb("workdir") return _prompt_store def set_prompt_store(new_prompt_store): @@ -17,9 +22,6 @@ def set_prompt_store(new_prompt_store): _prompt_store = new_prompt_store def get_sql_engine(): - global _sql_engine - if _sql_engine is None: - _sql_engine = get_sqlite_engine("workdir") return _sql_engine def set_sql_engine(new_sql_engine): @@ -27,9 +29,6 @@ def set_sql_engine(new_sql_engine): _sql_engine = new_sql_engine def get_docs(): - global _docs - if _docs is None: - _docs = scan_docs_directory("workdir") return _docs def set_docs(new_docs): @@ -37,9 +36,6 @@ def set_docs(new_docs): _docs = new_docs def get_datasets(): - global _datasets - if _datasets is None: - _datasets = get_all_dataset("workdir") return _datasets def set_datasets(new_datasets): diff --git a/main.py b/main.py index 3797c1e..1158201 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,13 @@ import gradio as gr from frontend.setting_page import setting_page from frontend import * -from db import initialize_sqlite_db,initialize_prompt_store -from global_var import sql_engine,prompt_store +from db import initialize_sqlite_db, initialize_prompt_store +from global_var import init_global_var, get_sql_engine, get_prompt_store if __name__ == "__main__": - initialize_sqlite_db(sql_engine) - initialize_prompt_store(prompt_store) + init_global_var() + initialize_sqlite_db(get_sql_engine()) + initialize_prompt_store(get_prompt_store()) with gr.Blocks() as app: gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架") with gr.Tabs():