import os import sys from sqlmodel import SQLModel, create_engine, Session from sqlmodel import select from typing import Optional from pathlib import Path from dotenv import load_dotenv from sqlalchemy.engine import Engine # 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块 sys.path.append(str(Path(__file__).resolve().parent.parent)) from schema.dataset_generation import APIProvider # 全局变量,用于存储数据库引擎实例 _engine: Optional[Engine] = None def load_sqlite_engine(workdir: str) -> Engine: """ 获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。 Args: workdir (str): 工作目录路径,用于确定数据库文件的存储位置。 Returns: Engine: SQLAlchemy 数据库引擎实例。 """ global _engine if not _engine: # 创建数据库目录(如果不存在) db_dir = os.path.join(workdir, "db") os.makedirs(db_dir, exist_ok=True) # 定义数据库文件路径 db_path = os.path.join(db_dir, "db.sqlite") # 创建数据库URL db_url = f"sqlite:///{db_path}" # 创建数据库引擎 _engine = create_engine(db_url) return _engine def initialize_sqlite_db(engine: Engine) -> None: """ 初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。 Args: engine (Engine): SQLAlchemy 数据库引擎实例。 """ # 创建所有表结构 SQLModel.metadata.create_all(engine) # 加载环境变量 load_dotenv() # 从环境变量中获取API相关配置 api_key = os.getenv("API_KEY") base_url = os.getenv("BASE_URL") model_id = os.getenv("MODEL_ID") # 如果所有必要的环境变量都存在,则插入初始数据 if api_key and base_url and model_id: with Session(engine) as session: # 查询是否已存在APIProvider记录 statement = select(APIProvider).limit(1) existing_provider = session.exec(statement).first() # 如果不存在,则插入新的APIProvider记录 if not existing_provider: provider = APIProvider( base_url=base_url, model_id=model_id, api_key=api_key ) session.add(provider) session.commit() if __name__ == "__main__": # 定义工作目录路径 workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") # 获取数据库引擎 engine = load_sqlite_engine(workdir) # 初始化数据库 initialize_sqlite_db(engine)