diff --git a/frontend/setting_page.py b/frontend/setting_page.py index 0789d3e..7b4867e 100644 --- a/frontend/setting_page.py +++ b/frontend/setting_page.py @@ -4,13 +4,12 @@ from sqlmodel import Session, select from db import get_engine from schema import APIProvider import os - -engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir")) +from global_var import sql_engine def setting_page(): def get_providers() -> List[List[str]]: try: # 添加异常处理 - with Session(engine) as session: + with Session(sql_engine) as session: providers = session.exec(select(APIProvider)).all() return [ [p.id, p.model_id, p.base_url, p.api_key or ""] @@ -21,7 +20,7 @@ def setting_page(): def add_provider(model_id, base_url, api_key): try: - with Session(engine) as session: + with Session(sql_engine) as session: new_provider = APIProvider( model_id=model_id if model_id else None, base_url=base_url if base_url else None, diff --git a/global_var.py b/global_var.py new file mode 100644 index 0000000..7a494b3 --- /dev/null +++ b/global_var.py @@ -0,0 +1,3 @@ +from db import get_engine + +sql_engine = get_engine("workdir") \ No newline at end of file diff --git a/main.py b/main.py index 8f8dc88..ac64130 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,11 @@ import gradio as gr from frontend.setting_page import setting_page from frontend import chat_page,setting_page,train_page,dataset_page -from db import initialize_db as init_db,get_engine +from db import initialize_db +from global_var import sql_engine if __name__ == "__main__": - init_db(get_engine("workdir")) - - pages = [] - pages.append(setting_page()) - pages.append(chat_page()) - pages.append(train_page()) - pages.append(dataset_page()) + initialize_db(sql_engine) with gr.Blocks() as app: gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")