refactor(schema): 重构数据模型以提高可维护性和可扩展性
- 新增 LLMParameters 类以统一处理 LLM 参数 - 新增 TokensUsage 类以统一处理 token 使用信息 - 更新 LLMResponse 和 LLMRequest 类,使用新的 LLMParameters 和 TokensUsage 类 - 优化数据模型结构,提高代码的可读性和可维护性
This commit is contained in:
parent
314434951d
commit
81c2ad4a2d
@ -1,4 +1,4 @@
|
||||
from .dataset import *
|
||||
from .dataset_generation import APIProvider, LLMResponse, LLMRequest
|
||||
from .dataset_generation import *
|
||||
from .md_doc import MarkdownNode
|
||||
from .prompt import promptTempleta
|
@ -12,40 +12,36 @@ class APIProvider(SQLModel, table=True):
|
||||
description="记录创建时间"
|
||||
)
|
||||
|
||||
class LLMParameters(SQLModel):
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
class TokensUsage(SQLModel):
|
||||
prompt_tokens: int = Field(default=0, description="提示词使用的token数量")
|
||||
completion_tokens: int = Field(default=0, description="完成部分使用的token数量")
|
||||
prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="缓存命中token数量")
|
||||
prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="缓存未命中token数量")
|
||||
|
||||
class LLMResponse(SQLModel):
|
||||
timestamp: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
description="响应的时间戳"
|
||||
)
|
||||
response_id: str = Field(..., description="响应的唯一ID")
|
||||
tokens_usage: dict = Field(default_factory=lambda: {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"prompt_cache_hit_tokens": None,
|
||||
"prompt_cache_miss_tokens": None
|
||||
}, description="token使用信息")
|
||||
tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
|
||||
response_content: dict = Field(default_factory=dict, description="API响应的内容")
|
||||
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
|
||||
llm_parameters: dict = Field(default_factory=lambda: {
|
||||
"temperature": None,
|
||||
"max_tokens": None,
|
||||
"top_p": None,
|
||||
"frequency_penalty": None,
|
||||
"presence_penalty": None,
|
||||
"seed": None
|
||||
}, description="API的生成参数")
|
||||
llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")
|
||||
|
||||
class LLMRequest(SQLModel):
|
||||
prompt: str = Field(..., description="发送给API的提示词")
|
||||
provider_id: int = Field(foreign_key="apiprovider.id")
|
||||
provider: APIProvider = Relationship()
|
||||
api_provider: APIProvider = Field(..., description="API提供者的信息")
|
||||
format: Optional[str] = Field(default=None, description="API响应的格式")
|
||||
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
|
||||
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
|
||||
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
|
||||
total_tokens_usage: dict = Field(default_factory=lambda: {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"prompt_cache_hit_tokens": None,
|
||||
"prompt_cache_miss_tokens": None
|
||||
}, description="token使用信息")
|
||||
total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
|
||||
|
Loading…
x
Reference in New Issue
Block a user