From 46b4453ccd5f93166fd7955219eb07a5eae30cb9 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Tue, 8 Apr 2025 13:19:58 +0800 Subject: [PATCH] =?UTF-8?q?refactor(frontend):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E8=BF=9E=E6=8E=A5=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除各前端页面中重复的数据库引擎初始化代码 - 在 global_var.py 中统一初始化和存储数据库引擎 - 更新 setting_page.py 和 main.py 中的数据库连接逻辑 - 优化代码结构,提高可维护性和可扩展性 --- frontend/setting_page.py | 7 +++---- global_var.py | 3 +++ main.py | 11 +++-------- 3 files changed, 9 insertions(+), 12 deletions(-) create mode 100644 global_var.py diff --git a/frontend/setting_page.py b/frontend/setting_page.py index 0789d3e..7b4867e 100644 --- a/frontend/setting_page.py +++ b/frontend/setting_page.py @@ -4,13 +4,12 @@ from sqlmodel import Session, select from db import get_engine from schema import APIProvider import os - -engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir")) +from global_var import sql_engine def setting_page(): def get_providers() -> List[List[str]]: try: # 添加异常处理 - with Session(engine) as session: + with Session(sql_engine) as session: providers = session.exec(select(APIProvider)).all() return [ [p.id, p.model_id, p.base_url, p.api_key or ""] @@ -21,7 +20,7 @@ def setting_page(): def add_provider(model_id, base_url, api_key): try: - with Session(engine) as session: + with Session(sql_engine) as session: new_provider = APIProvider( model_id=model_id if model_id else None, base_url=base_url if base_url else None, diff --git a/global_var.py b/global_var.py new file mode 100644 index 0000000..7a494b3 --- /dev/null +++ b/global_var.py @@ -0,0 +1,3 @@ +from db import get_engine + +sql_engine = get_engine("workdir") \ No newline at end of file diff --git a/main.py b/main.py index 8f8dc88..ac64130 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,11 @@ import gradio as gr from frontend.setting_page import setting_page from frontend import chat_page,setting_page,train_page,dataset_page -from db import initialize_db as init_db,get_engine +from db import initialize_db +from global_var import sql_engine if __name__ == "__main__": - init_db(get_engine("workdir")) - - pages = [] - pages.append(setting_page()) - pages.append(chat_page()) - pages.append(train_page()) - pages.append(dataset_page()) + initialize_db(sql_engine) with gr.Blocks() as app: gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")