
- 新增 init_db.py 文件，实现数据库初始化和 APIProvider 表的创建
- 新增 dataset_generation.py 文件，定义 LLMRequest、LLMResponse 和 APIProvider 模型
- 在初始化数据库时，如果环境变量中存在 API_KEY、BASE_URL 和 MODEL_ID，会自动添加一条 APIProvider 记录
51 lines
2.3 KiB
Python
51 lines
2.3 KiB
Python
from datetime import datetime, timezone
|
|
from typing import Optional
|
|
from sqlmodel import SQLModel, Relationship, Field
|
|
|
|
class APIProvider(SQLModel, table=True):
    """An LLM API endpoint configuration, persisted as a database table.

    Each row pairs an API base URL with the model ID to request from it and
    an optional API key used for authentication.
    """

    # Primary key; defaults to None on a freshly built instance
    # (presumably filled in by the database on insert — typical for an
    # integer primary key).
    id: Optional[int] = Field(default=None, primary_key=True)
    base_url: str = Field(..., description="API的基础URL")
    model_id: str = Field(..., description="API使用的模型ID")
    # repr=False keeps the secret out of repr()/str() of the model, which
    # otherwise ends up verbatim in log lines and tracebacks.
    api_key: Optional[str] = Field(
        default=None,
        description="用于身份验证的API密钥",
        repr=False,
    )
    # Defaults to the current UTC time, computed when the instance is created.
    # NOTE(review): a plain DateTime column may not preserve tzinfo on every
    # backend (e.g. SQLite), so values read back could be naive — confirm.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="记录创建时间",
    )
|
|
|
|
class LLMResponse(SQLModel):
    """One response returned by an LLM API call (plain model, not a table).

    Holds the time the instance was created, the response ID, token
    accounting, the raw response content, the call duration in seconds, and
    the generation parameters associated with the call.
    """

    # Defaults to "now" in UTC, evaluated per instance.
    timestamp: datetime = Field(
        description="响应的时间戳",
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
    response_id: str = Field(..., description="响应的唯一ID")
    # Prompt/completion counts start at zero; the cache hit/miss counts
    # default to None (presumably "not reported by the provider").
    tokens_usage: dict = Field(
        description="token使用信息",
        default_factory=lambda: dict(
            prompt_tokens=0,
            completion_tokens=0,
            prompt_cache_hit_tokens=None,
            prompt_cache_miss_tokens=None,
        ),
    )
    response_content: dict = Field(description="API响应的内容", default_factory=dict)
    # Duration in seconds; 0.0 until a caller records it.
    total_duration: float = Field(0.0, description="请求的总时长,单位为秒")
    # Every generation parameter defaults to None, i.e. not explicitly set.
    llm_parameters: dict = Field(
        description="API的生成参数",
        default_factory=lambda: dict(
            temperature=None,
            max_tokens=None,
            top_p=None,
            frequency_penalty=None,
            presence_penalty=None,
            seed=None,
        ),
    )
|
|
|
|
class LLMRequest(SQLModel):
    """A prompt sent to an LLM provider together with the responses it produced.

    This is a plain (non-table) SQLModel, so it is not mapped to a database
    table. It aggregates the individual LLMResponse objects, any error
    messages, the overall duration, and the overall token usage.
    """

    prompt: str = Field(..., description="发送给API的提示词")
    # NOTE(review): this class is not declared with table=True. SQLModel only
    # applies foreign keys and sets up Relationship() attributes on table
    # models, so the foreign_key below has no database effect and `provider`
    # is neither a validated field nor populated automatically. Confirm
    # whether this model should become a table or carry the provider another
    # way (e.g. look it up via provider_id).
    provider_id: int = Field(
        foreign_key="apiprovider.id",
        description="关联的APIProvider的ID",
    )
    provider: APIProvider = Relationship()
    # Requested response format; None means no specific format was requested.
    format: Optional[str] = Field(default=None, description="API响应的格式")
    response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
    # None when no error was recorded; otherwise a list of error messages.
    error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
    # Duration in seconds; 0.0 until a caller records it.
    total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
    # Same key layout as LLMResponse.tokens_usage: counts start at zero and
    # cache hit/miss counts default to None.
    total_tokens_usage: dict = Field(default_factory=lambda: {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "prompt_cache_hit_tokens": None,
        "prompt_cache_miss_tokens": None,
    }, description="token使用信息")