diff --git a/db/init_db.py b/db/init_db.py
new file mode 100644
index 0000000..02dc0a0
--- /dev/null
+++ b/db/init_db.py
@@ -0,0 +1,70 @@
+"""Set up the SQLite engine and seed a default API provider."""
+
+import os
+import sys
+from pathlib import Path
+from typing import Optional
+
+from dotenv import load_dotenv
+from sqlalchemy.engine import Engine
+from sqlmodel import Session, SQLModel, create_engine, select
+
+# Make the project root importable so ``schema`` resolves when this file is
+# executed directly as a script.
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from schema.dataset_generation import APIProvider  # noqa: E402
+
+# Process-wide engine, created lazily by get_engine().
+_engine: Optional[Engine] = None
+
+
+def get_engine(workdir: str) -> Engine:
+    """Return the shared SQLite engine, creating it on first use.
+
+    The database file is ``db/db.sqlite`` under ``workdir``; the ``db``
+    directory is created if it does not exist.
+
+    NOTE: the engine is cached in a module-level global, so ``workdir`` is
+    only honoured by the first call. Later calls return the cached engine
+    even if a different ``workdir`` is passed.
+    """
+    global _engine
+    if _engine is None:
+        db_dir = os.path.join(workdir, "db")
+        os.makedirs(db_dir, exist_ok=True)
+        db_path = os.path.join(db_dir, "db.sqlite")
+        _engine = create_engine(f"sqlite:///{db_path}")
+    return _engine
+
+
+def initialize_db(engine: Engine) -> None:
+    """Create all tables and seed a default provider from the environment.
+
+    After ``load_dotenv()`` runs, a provider row is inserted only when the
+    ``API_KEY``, ``BASE_URL`` and ``MODEL_ID`` environment variables are all
+    set and the ``APIProvider`` table has no rows yet.
+    """
+    SQLModel.metadata.create_all(engine)
+    load_dotenv()
+    api_key = os.getenv("API_KEY")
+    base_url = os.getenv("BASE_URL")
+    model_id = os.getenv("MODEL_ID")
+
+    # Without a complete configuration there is nothing to seed.
+    if not (api_key and base_url and model_id):
+        return
+
+    with Session(engine) as session:
+        # Seed only once: skip when any provider row already exists.
+        if session.exec(select(APIProvider).limit(1)).first() is not None:
+            return
+        session.add(
+            APIProvider(base_url=base_url, model_id=model_id, api_key=api_key)
+        )
+        session.commit()
+
+
+if __name__ == "__main__":
+    workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
+    engine = get_engine(workdir)
+    initialize_db(engine)
diff --git a/schema/dataset_generation.py b/schema/dataset_generation.py
new file mode 100644
index 0000000..30eb354
--- /dev/null
+++ b/schema/dataset_generation.py
@@ -0,0 +1,87 @@
+"""Data models for API providers and LLM request/response records."""
+
+from datetime import datetime, timezone
+from typing import Optional
+
+from sqlmodel import SQLModel, Relationship, Field
+
+
+def _utc_now() -> datetime:
+    """Return the current time as a timezone-aware UTC datetime."""
+    return datetime.now(timezone.utc)
+
+
+def _empty_token_usage() -> dict:
+    """Return a fresh token-usage dict with zeroed counters.
+
+    The cache hit/miss counters start as ``None``.
+    """
+    return {
+        "prompt_tokens": 0,
+        "completion_tokens": 0,
+        "prompt_cache_hit_tokens": None,
+        "prompt_cache_miss_tokens": None,
+    }
+
+
+def _empty_llm_parameters() -> dict:
+    """Return a fresh generation-parameter dict with every value unset."""
+    return {
+        "temperature": None,
+        "max_tokens": None,
+        "top_p": None,
+        "frequency_penalty": None,
+        "presence_penalty": None,
+        "seed": None,
+    }
+
+
+class APIProvider(SQLModel, table=True):
+    """Database table holding the connection settings of one API provider."""
+
+    id: Optional[int] = Field(default=None, primary_key=True)
+    base_url: str = Field(..., description="API的基础URL")
+    model_id: str = Field(..., description="API使用的模型ID")
+    api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
+    created_at: datetime = Field(
+        default_factory=_utc_now,
+        description="记录创建时间",
+    )
+
+
+class LLMResponse(SQLModel):
+    """A single response returned for an LLM request (not a table model)."""
+
+    timestamp: datetime = Field(
+        default_factory=_utc_now,
+        description="响应的时间戳",
+    )
+    response_id: str = Field(..., description="响应的唯一ID")
+    tokens_usage: dict = Field(
+        default_factory=_empty_token_usage,
+        description="token使用信息",
+    )
+    response_content: dict = Field(default_factory=dict, description="API响应的内容")
+    total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
+    llm_parameters: dict = Field(
+        default_factory=_empty_llm_parameters,
+        description="API的生成参数",
+    )
+
+
+class LLMRequest(SQLModel):
+    """A prompt sent to a provider together with its collected responses."""
+
+    prompt: str = Field(..., description="发送给API的提示词")
+    # NOTE(review): this class is not declared with table=True; confirm that
+    # foreign_key and Relationship() behave as intended on a non-table model.
+    provider_id: int = Field(foreign_key="apiprovider.id")
+    provider: APIProvider = Relationship()
+    format: Optional[str] = Field(default=None, description="API响应的格式")
+    response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
+    error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
+    total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
+    total_tokens_usage: dict = Field(
+        default_factory=_empty_token_usage,
+        description="token使用信息",
+    )