gzhu-biyesheji/schema/dataset_generation.py
carry 967133162e refactor(schema): make the id field immutable in the APIProvider model
- In the APIProvider class, the id field definition was updated to add the allow_mutation=False parameter
- This ensures the primary-key field cannot be changed after creation, improving data consistency and safety
2025-04-08 16:02:46 +08:00
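The snippet below is not part of the file; it is a minimal sketch of the intent behind allow_mutation=False, written against plain Pydantic v1 (the generation this SQLModel release builds on). Note that Pydantic v1 enforces field-level allow_mutation on assignment only when validate_assignment is enabled in the model config; the model name and values here are hypothetical.

from pydantic import BaseModel, Field

class FrozenIdExample(BaseModel):  # hypothetical model, for illustration only
    id: int = Field(..., allow_mutation=False)

    class Config:
        validate_assignment = True  # required for allow_mutation to be enforced

example = FrozenIdExample(id=1)
example.id = 2  # raises TypeError: "id" has allow_mutation set to False and cannot be assigned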


from datetime import datetime, timezone
from typing import Optional

from sqlmodel import SQLModel, Relationship, Field


class APIProvider(SQLModel, table=True):
    """Connection settings for an LLM API provider (database table)."""

    id: Optional[int] = Field(default=None, primary_key=True, allow_mutation=False)
    base_url: str = Field(..., min_length=1, description="Base URL of the API; must not be empty")
    model_id: str = Field(..., min_length=1, description="Model ID used by the API; must not be empty")
    api_key: Optional[str] = Field(default=None, description="API key used for authentication")
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="Time the record was created"
    )


class LLMResponse(SQLModel):
    """A single response returned by the LLM API."""

    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="Timestamp of the response"
    )
    response_id: str = Field(..., description="Unique ID of the response")
    tokens_usage: dict = Field(default_factory=lambda: {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "prompt_cache_hit_tokens": None,
        "prompt_cache_miss_tokens": None
    }, description="Token usage information")
    response_content: dict = Field(default_factory=dict, description="Content of the API response")
    total_duration: float = Field(default=0.0, description="Total duration of the request, in seconds")
    llm_parameters: dict = Field(default_factory=lambda: {
        "temperature": None,
        "max_tokens": None,
        "top_p": None,
        "frequency_penalty": None,
        "presence_penalty": None,
        "seed": None
    }, description="Generation parameters sent to the API")


class LLMRequest(SQLModel):
    """A request sent to the LLM API, together with its responses."""

    prompt: str = Field(..., description="Prompt sent to the API")
    provider_id: int = Field(foreign_key="apiprovider.id")
    provider: APIProvider = Relationship()
    format: Optional[str] = Field(default=None, description="Format of the API response")
    response: list[LLMResponse] = Field(default_factory=list, description="List of API responses")
    error: Optional[list[str]] = Field(default=None, description="Error messages raised during the API request")
    total_duration: float = Field(default=0.0, description="Total duration of the request, in seconds")
    total_tokens_usage: dict = Field(default_factory=lambda: {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "prompt_cache_hit_tokens": None,
        "prompt_cache_miss_tokens": None
    }, description="Aggregated token usage information")
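

# Usage sketch (not part of the original module): a minimal, hedged example of
# how these models might be wired together. It assumes a local in-memory SQLite
# database and illustrative values (URL, model name, response payload) that are
# not taken from the project.
if __name__ == "__main__":
    from sqlmodel import Session, create_engine

    engine = create_engine("sqlite:///:memory:")
    SQLModel.metadata.create_all(engine)  # creates the apiprovider table

    with Session(engine) as session:
        provider = APIProvider(
            base_url="https://api.example.com/v1",  # illustrative value
            model_id="example-model",               # illustrative value
        )
        session.add(provider)
        session.commit()
        session.refresh(provider)  # populates provider.id

    # The non-table models are built in memory and can be serialized as needed.
    response = LLMResponse(
        response_id="resp-001",
        response_content={"text": "hello"},
        total_duration=1.2,
    )
    request = LLMRequest(
        prompt="Hello, world",
        provider_id=provider.id,
        response=[response],
        total_duration=1.2,
    )
    print(request.json(indent=2))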