
- 新增 init_db.py 文件，实现数据库初始化和 APIProvider 表的创建
- 新增 dataset_generation.py 文件，定义 LLMRequest、LLMResponse 和 APIProvider 模型
- 在初始化数据库时，如果环境变量中同时存在 API_KEY、BASE_URL 和 MODEL_ID，且表中尚无任何 APIProvider 记录，会自动添加一条 APIProvider 记录
52 lines
1.6 KiB
Python
52 lines
1.6 KiB
Python
import os
import sys
from pathlib import Path
from typing import Dict, Optional

from dotenv import load_dotenv
from sqlalchemy.engine import Engine
from sqlmodel import Session, SQLModel, create_engine, select

# Project root = the parent of the directory that contains this file. It is
# appended to sys.path (searched after existing entries) so the project-local
# ``schema`` package imported below resolves even when this file is executed
# directly as a script.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(_PROJECT_ROOT))
from schema.dataset_generation import APIProvider
|
|
|
|
# Engine cache keyed by absolute SQLite file path. An engine owns a connection
# pool, so it is reused per database instead of being rebuilt on every call,
# while different workdirs still get their own engine.
_engines: Dict[str, Engine] = {}

# Most recently returned engine; kept for backward compatibility with any code
# that reads this module-level name directly.
_engine: Optional[Engine] = None


def get_engine(workdir: str) -> Engine:
    """Return the SQLite engine for ``<workdir>/db/db.sqlite``, creating it on first use.

    The ``db`` sub-directory is created if it does not exist. Engines are
    cached per database path: repeated calls with the same ``workdir`` return
    the same engine, and a different ``workdir`` gets its own engine (it is
    no longer silently mapped to whichever database was opened first).

    Args:
        workdir: Root working directory. A relative path is resolved against
            the current working directory at call time.

    Returns:
        A SQLAlchemy ``Engine`` bound to the SQLite database file.
    """
    global _engine

    # Make the path absolute so equivalent spellings ("a/../b" vs "b") share
    # one engine, and so a later os.chdir() cannot redirect a relative
    # sqlite:/// URL to a different file.
    db_dir = os.path.join(os.path.abspath(workdir), "db")
    db_path = os.path.join(db_dir, "db.sqlite")

    engine = _engines.get(db_path)
    if engine is None:
        os.makedirs(db_dir, exist_ok=True)
        engine = create_engine(f"sqlite:///{db_path}")
        _engines[db_path] = engine

    _engine = engine
    return engine


def initialize_db(engine: Engine) -> None:
    """Create all SQLModel tables and optionally seed one default API provider.

    After ``SQLModel.metadata.create_all`` runs, ``load_dotenv()`` is called so
    values from a ``.env`` file can populate the environment. If ``API_KEY``,
    ``BASE_URL`` and ``MODEL_ID`` are all set to non-empty values and the
    ``APIProvider`` table contains no rows yet, a single provider built from
    those values is inserted and committed. Otherwise nothing is written.

    Args:
        engine: Engine whose database should be initialised.
    """
    SQLModel.metadata.create_all(engine)

    load_dotenv()
    settings = {
        var_name: os.getenv(var_name)
        for var_name in ("API_KEY", "BASE_URL", "MODEL_ID")
    }
    # Seeding needs all three values; a partial configuration is ignored.
    if not all(settings.values()):
        return

    with Session(engine) as session:
        # Only seed an empty table, so re-running never adds a duplicate or
        # overrides a provider that is already stored.
        already_seeded = session.exec(select(APIProvider).limit(1)).first()
        if already_seeded:
            return

        session.add(
            APIProvider(
                base_url=settings["BASE_URL"],
                model_id=settings["MODEL_ID"],
                api_key=settings["API_KEY"],
            )
        )
        session.commit()


if __name__ == "__main__":
    # Script entry point: initialise the database under "<this file's
    # directory>/../workdir", located relative to this file rather than the
    # current working directory.
    script_dir = os.path.dirname(__file__)
    initialize_db(get_engine(os.path.join(script_dir, os.pardir, "workdir")))