refactor(frontend): 重构数据库连接方式

- 移除各前端页面中重复的数据库引擎初始化代码
- 在 global_var.py 中统一初始化和存储数据库引擎
- 更新 setting_page.py 和 main.py 中的数据库连接逻辑
- 优化代码结构,提高可维护性和可扩展性
This commit is contained in:
carry 2025-04-08 13:19:58 +08:00
parent d5b528d375
commit 46b4453ccd
3 changed files with 9 additions and 12 deletions

View File

@ -4,13 +4,12 @@ from sqlmodel import Session, select
from db import get_engine from db import get_engine
from schema import APIProvider from schema import APIProvider
import os import os
from global_var import sql_engine
engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir"))
def setting_page(): def setting_page():
def get_providers() -> List[List[str]]: def get_providers() -> List[List[str]]:
try: # 添加异常处理 try: # 添加异常处理
with Session(engine) as session: with Session(sql_engine) as session:
providers = session.exec(select(APIProvider)).all() providers = session.exec(select(APIProvider)).all()
return [ return [
[p.id, p.model_id, p.base_url, p.api_key or ""] [p.id, p.model_id, p.base_url, p.api_key or ""]
@ -21,7 +20,7 @@ def setting_page():
def add_provider(model_id, base_url, api_key): def add_provider(model_id, base_url, api_key):
try: try:
with Session(engine) as session: with Session(sql_engine) as session:
new_provider = APIProvider( new_provider = APIProvider(
model_id=model_id if model_id else None, model_id=model_id if model_id else None,
base_url=base_url if base_url else None, base_url=base_url if base_url else None,

3
global_var.py Normal file
View File

@ -0,0 +1,3 @@
from db import get_engine
sql_engine = get_engine("workdir")

11
main.py
View File

@ -1,16 +1,11 @@
import gradio as gr import gradio as gr
from frontend.setting_page import setting_page from frontend.setting_page import setting_page
from frontend import chat_page,setting_page,train_page,dataset_page from frontend import chat_page,setting_page,train_page,dataset_page
from db import initialize_db as init_db,get_engine from db import initialize_db
from global_var import sql_engine
if __name__ == "__main__": if __name__ == "__main__":
init_db(get_engine("workdir")) initialize_db(sql_engine)
pages = []
pages.append(setting_page())
pages.append(chat_page())
pages.append(train_page())
pages.append(dataset_page())
with gr.Blocks() as app: with gr.Blocks() as app:
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架") gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")