From b1e98ca913ffa0c9ddd65a0709c0fc0e4c0059e1 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Sun, 6 Apr 2025 19:59:23 +0800 Subject: [PATCH] =?UTF-8?q?feat(db):=20=E5=88=9D=E5=A7=8B=E5=8C=96?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=B9=B6=E5=88=9B=E5=BB=BA=20APIPro?= =?UTF-8?q?vider=20=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 init_db.py 文件,实现数据库初始化和 APIProvider 表的创建 - 新增 dataset_generation.py 文件,定义 LLMRequest、LLMResponse 和 APIProvider 模型 - 在初始化数据库时,如果环境变量中存在 API_KEY、BASE_URL 和 MODEL_ID,会自动添加一条 APIProvider 记录 --- db/init_db.py | 52 ++++++++++++++++++++++++++++++++++++ schema/dataset_generation.py | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 db/init_db.py create mode 100644 schema/dataset_generation.py diff --git a/db/init_db.py b/db/init_db.py new file mode 100644 index 0000000..02dc0a0 --- /dev/null +++ b/db/init_db.py @@ -0,0 +1,52 @@ +from sqlmodel import SQLModel, create_engine, Session +from sqlmodel import select +from typing import Optional +import os +from pathlib import Path +import sys +from dotenv import load_dotenv +from sqlalchemy.engine import Engine + +# 项目路径配置 +sys.path.append(str(Path(__file__).resolve().parent.parent)) +from schema.dataset_generation import APIProvider + +# 全局引擎实例(可选) +_engine: Optional[Engine] = None + +def get_engine(workdir: str) -> Engine: + global _engine + if not _engine: + db_dir = os.path.join(workdir, "db") + os.makedirs(db_dir, exist_ok=True) + db_path = os.path.join(db_dir, "db.sqlite") + db_url = f"sqlite:///{db_path}" + _engine = create_engine(db_url) + return _engine + +def initialize_db(engine: Engine) -> None: + SQLModel.metadata.create_all(engine) + load_dotenv() + api_key = os.getenv("API_KEY") + base_url = os.getenv("BASE_URL") + model_id = os.getenv("MODEL_ID") + + if api_key and base_url and model_id: + with Session(engine) as session: + # 使用新的 select() 语法查询 + statement = select(APIProvider).limit(1) + existing_provider = session.exec(statement).first() + + if not existing_provider: + provider = APIProvider( + base_url=base_url, + model_id=model_id, + api_key=api_key + ) + session.add(provider) + session.commit() + +if __name__ == "__main__": + workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") + engine = get_engine(workdir) + initialize_db(engine) \ No newline at end of file diff --git a/schema/dataset_generation.py b/schema/dataset_generation.py new file mode 100644 index 0000000..30eb354 --- /dev/null +++ b/schema/dataset_generation.py @@ -0,0 +1,51 @@ +from datetime import datetime, timezone +from typing import Optional +from sqlmodel import SQLModel, Relationship, Field + +class APIProvider(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + base_url: str = Field(..., description="API的基础URL") + model_id: str = Field(..., description="API使用的模型ID") + api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥") + created_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="记录创建时间" + ) + +class LLMResponse(SQLModel): + timestamp: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="响应的时间戳" + ) + response_id: str = Field(..., description="响应的唯一ID") + tokens_usage: dict = Field(default_factory=lambda: { + "prompt_tokens": 0, + "completion_tokens": 0, + "prompt_cache_hit_tokens": None, + "prompt_cache_miss_tokens": None + }, description="token使用信息") + response_content: dict = Field(default_factory=dict, description="API响应的内容") + total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") + llm_parameters: dict = Field(default_factory=lambda: { + "temperature": None, + "max_tokens": None, + "top_p": None, + "frequency_penalty": None, + "presence_penalty": None, + "seed": None + }, description="API的生成参数") + +class LLMRequest(SQLModel): + prompt: str = Field(..., description="发送给API的提示词") + provider_id: int = Field(foreign_key="apiprovider.id") + provider: APIProvider = Relationship() + format: Optional[str] = Field(default=None, description="API响应的格式") + response: list[LLMResponse] = Field(default_factory=list, description="API响应列表") + error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息") + total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") + total_tokens_usage: dict = Field(default_factory=lambda: { + "prompt_tokens": 0, + "completion_tokens": 0, + "prompt_cache_hit_tokens": None, + "prompt_cache_miss_tokens": None + }, description="token使用信息") \ No newline at end of file