feat(db): 初始化数据库并创建 APIProvider 表
- 新增 init_db.py 文件,实现数据库初始化和 APIProvider 表的创建
- 新增 dataset_generation.py 文件,定义 LLMRequest、LLMResponse 和 APIProvider 模型
- 在初始化数据库时,如果环境变量中存在 API_KEY、BASE_URL 和 MODEL_ID,会自动添加一条 APIProvider 记录
This commit is contained in:
parent
2d5a5277ae
commit
b1e98ca913
52
db/init_db.py
Normal file
52
db/init_db.py
Normal file
@ -0,0 +1,52 @@
|
||||
from sqlmodel import SQLModel, create_engine, Session
|
||||
from sqlmodel import select
|
||||
from typing import Optional
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
# 项目路径配置
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.dataset_generation import APIProvider
|
||||
|
||||
# Process-wide engine cache: populated by the first get_engine() call and
# returned unchanged by every later call.
_engine: Optional[Engine] = None


def get_engine(workdir: str) -> Engine:
    """Return the shared SQLite engine, creating it on first use.

    The database file is ``<workdir>/db/db.sqlite``; the ``db`` directory is
    created if it does not exist yet.

    Note:
        The engine is cached in the module-level ``_engine``. Once it has
        been created, later calls return that cached instance and their
        ``workdir`` argument is not consulted.

    Args:
        workdir: Directory under which the ``db`` folder is placed.

    Returns:
        The cached engine instance.
    """
    global _engine
    # Compare against None explicitly instead of relying on the truthiness
    # of the engine object.
    if _engine is None:
        db_dir = os.path.join(workdir, "db")
        os.makedirs(db_dir, exist_ok=True)
        db_path = os.path.join(db_dir, "db.sqlite")
        _engine = create_engine(f"sqlite:///{db_path}")
    return _engine
|
||||
|
||||
def initialize_db(engine: Engine) -> None:
    """Create all registered tables and optionally seed one APIProvider row.

    After the tables are created, ``load_dotenv()`` is called and the
    ``API_KEY``, ``BASE_URL`` and ``MODEL_ID`` environment variables are
    read. When all three are non-empty and the ``APIProvider`` table holds
    no rows yet, a single provider built from those values is inserted.

    Args:
        engine: Engine bound to the database to initialise.
    """
    SQLModel.metadata.create_all(engine)

    load_dotenv()
    settings = {
        name: os.getenv(name) for name in ("API_KEY", "BASE_URL", "MODEL_ID")
    }
    # Seeding needs every one of the three settings; otherwise stop here.
    if not all(settings.values()):
        return

    with Session(engine) as session:
        # Fetching a single row is enough to tell whether any provider exists.
        any_provider = session.exec(select(APIProvider).limit(1)).first()
        if any_provider:
            return

        session.add(
            APIProvider(
                base_url=settings["BASE_URL"],
                model_id=settings["MODEL_ID"],
                api_key=settings["API_KEY"],
            )
        )
        session.commit()
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: initialise the database under "../workdir",
    # resolved relative to the directory containing this file.
    script_dir = os.path.dirname(__file__)
    workdir = os.path.join(script_dir, "..", "workdir")
    initialize_db(get_engine(workdir))
|
51
schema/dataset_generation.py
Normal file
51
schema/dataset_generation.py
Normal file
@ -0,0 +1,51 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Relationship, Field
|
||||
|
||||
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class APIProvider(SQLModel, table=True):
    """Table model storing the connection settings of one API provider."""

    # Primary key; defaults to None when not supplied.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Required connection settings.
    base_url: str = Field(..., description="API的基础URL")
    model_id: str = Field(..., description="API使用的模型ID")
    # Optional credential; None when no key is supplied.
    api_key: Optional[str] = Field(
        default=None,
        description="用于身份验证的API密钥",
    )
    # Defaults to the UTC time at which the default is evaluated.
    created_at: datetime = Field(
        default_factory=_utc_now,
        description="记录创建时间",
    )
|
||||
|
||||
def _default_tokens_usage() -> dict:
    """Build a fresh token-usage mapping: zero counts, cache fields unset."""
    return {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "prompt_cache_hit_tokens": None,
        "prompt_cache_miss_tokens": None,
    }


def _default_llm_parameters() -> dict:
    """Build a fresh generation-parameter mapping with every value None."""
    return dict.fromkeys(
        (
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "seed",
        )
    )


class LLMResponse(SQLModel):
    """A single API response together with its usage and timing data."""

    # Defaults to the current UTC time when not supplied.
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="响应的时间戳",
    )
    response_id: str = Field(..., description="响应的唯一ID")
    # Each instance gets its own dict, built by the factory.
    tokens_usage: dict = Field(
        default_factory=_default_tokens_usage,
        description="token使用信息",
    )
    response_content: dict = Field(
        default_factory=dict,
        description="API响应的内容",
    )
    total_duration: float = Field(
        default=0.0,
        description="请求的总时长,单位为秒",
    )
    llm_parameters: dict = Field(
        default_factory=_default_llm_parameters,
        description="API的生成参数",
    )
|
||||
|
||||
def _default_total_tokens_usage() -> dict:
    """Build a fresh aggregate token-usage mapping for a request."""
    return {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "prompt_cache_hit_tokens": None,
        "prompt_cache_miss_tokens": None,
    }


class LLMRequest(SQLModel):
    """A prompt sent to an API provider, with its responses and totals.

    NOTE(review): this model is declared without ``table=True`` yet uses
    ``foreign_key`` and ``Relationship()`` -- confirm this is intended.
    """

    prompt: str = Field(..., description="发送给API的提示词")
    # References the ``id`` column of the ``apiprovider`` table.
    provider_id: int = Field(foreign_key="apiprovider.id")
    provider: APIProvider = Relationship()
    format: Optional[str] = Field(
        default=None,
        description="API响应的格式",
    )
    # Starts as an empty list; each instance gets its own list.
    response: list[LLMResponse] = Field(
        default_factory=list,
        description="API响应列表",
    )
    error: Optional[list[str]] = Field(
        default=None,
        description="API请求过程中发生的错误信息",
    )
    total_duration: float = Field(
        default=0.0,
        description="请求的总时长,单位为秒",
    )
    total_tokens_usage: dict = Field(
        default_factory=_default_total_tokens_usage,
        description="token使用信息",
    )
|
Loading…
x
Reference in New Issue
Block a user