diff --git a/frontend/setting_page.py b/frontend/setting_page.py index 36c03bb..2739930 100644 --- a/frontend/setting_page.py +++ b/frontend/setting_page.py @@ -1,9 +1,35 @@ import gradio as gr +from typing import List, Dict +from sqlmodel import Session, select +from db import get_engine +from schema import APIProvider +import os + +# 获取数据库引擎 +engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir")) def setting_page(): + def get_providers() -> List[List[str]]: + with Session(engine) as session: + providers = session.exec(select(APIProvider)).all() + return [ + [p.id, p.model_id, p.base_url, p.api_key or ""] + for p in providers + ] + with gr.Blocks() as demo: - gr.Markdown("## 设置") + gr.Markdown("## API Provider 管理") + with gr.Row(): - with gr.Column(): - pass + # API Provider列表 + with gr.Column(scale=2): + provider_table = gr.DataFrame( + headers=["id" , "model id", "URL", "API Key"], + datatype=["number","str", "str", "str"], + interactive=True, + value=get_providers(), + wrap=True, + col_count=(4, "fixed") + ) + return demo \ No newline at end of file diff --git a/main.py b/main.py index 0a915d2..5a93227 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,11 @@ import gradio as gr from frontend.setting_page import setting_page from frontend.chat_page import chat_page from frontend.train_page import train_page -def main(): +from db import initialize_db as init_db,get_engine + +if __name__ == "__main__": + init_db(get_engine("workdir")) + setting_demo = setting_page() chat_demo = chat_page() train_demo = train_page() @@ -17,7 +21,4 @@ def main(): with gr.TabItem("设置"): setting_demo.render() - app.launch() - -if __name__ == "__main__": - main() \ No newline at end of file + app.launch() \ No newline at end of file