"""Database bootstrap helpers: build the SQLite engine and seed initial data."""

import os
import sys
from pathlib import Path
from typing import Dict, Optional

from dotenv import load_dotenv
from sqlalchemy.engine import Engine
from sqlmodel import Session, SQLModel, create_engine, select

# Project path setup: append the directory two levels above this file
# (presumably the project root) to sys.path so the local `schema` package
# imported below can be resolved.
sys.path.append(str(Path(__file__).resolve().parent.parent))

from schema.dataset_generation import APIProvider  # noqa: E402  (needs sys.path above)

# Engines created so far, keyed by database URL, so repeated calls for the
# same workdir share a single engine instead of building a new one each time.
_engines: Dict[str, Engine] = {}

# Global engine instance (optional): the engine most recently returned by
# get_engine(). Kept for backward compatibility with code that reads it.
_engine: Optional[Engine] = None


def get_engine(workdir: str) -> Engine:
    """Return the SQLite engine for ``<workdir>/db/db.sqlite``.

    The ``db`` directory is created if it does not exist. Engines are cached
    per database file: calling this again with the same ``workdir`` returns
    the same engine, while a different ``workdir`` gets its own engine rather
    than silently reusing whichever engine was created first.

    Args:
        workdir: Directory under which the ``db`` folder is located.

    Returns:
        An ``Engine`` bound to the SQLite database file.
    """
    global _engine
    # Absolute path so the cache key is canonical (``a/../b`` and ``b`` match)
    # and the URL keeps pointing at the same file even if the CWD changes.
    db_dir = os.path.abspath(os.path.join(workdir, "db"))
    db_url = f"sqlite:///{os.path.join(db_dir, 'db.sqlite')}"

    # NOTE: no locking here; concurrent first calls for the same URL could
    # each create an engine (same as the previous single-global version).
    engine = _engines.get(db_url)
    if engine is None:
        os.makedirs(db_dir, exist_ok=True)
        engine = create_engine(db_url)
        _engines[db_url] = engine

    _engine = engine
    return engine


def initialize_db(engine: Engine) -> None:
    """Create all tables and seed a default API provider from the environment.

    Creates every table registered on ``SQLModel.metadata``, then calls
    ``load_dotenv()`` and reads ``API_KEY``, ``BASE_URL`` and ``MODEL_ID``
    from the environment. If all three are set to non-empty values and the
    ``APIProvider`` table has no rows yet, a single ``APIProvider`` row is
    inserted with those values. If any provider row already exists, nothing
    is inserted, even if its settings differ from the environment.

    Args:
        engine: Engine for the target database, e.g. from ``get_engine``.
    """
    SQLModel.metadata.create_all(engine)

    load_dotenv()
    api_key = os.getenv("API_KEY")
    base_url = os.getenv("BASE_URL")
    model_id = os.getenv("MODEL_ID")

    # Seeding is optional: skip it unless the full provider config is present.
    if not (api_key and base_url and model_id):
        return

    with Session(engine) as session:
        # Query with SQLModel's select() syntax; any existing row means the
        # database has already been seeded.
        statement = select(APIProvider).limit(1)
        existing_provider = session.exec(statement).first()
        if existing_provider is None:
            provider = APIProvider(
                base_url=base_url,
                model_id=model_id,
                api_key=api_key,
            )
            session.add(provider)
            session.commit()


if __name__ == "__main__":
    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
    engine = get_engine(workdir)
    initialize_db(engine)