From 81c2ad4a2d328e2fb5e4523b533139d76a9a46d2 Mon Sep 17 00:00:00 2001 From: carry Date: Sat, 19 Apr 2025 16:39:18 +0800 Subject: [PATCH] =?UTF-8?q?refactor(schema):=20=E9=87=8D=E6=9E=84=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=A8=A1=E5=9E=8B=E4=BB=A5=E6=8F=90=E9=AB=98=E5=8F=AF?= =?UTF-8?q?=E7=BB=B4=E6=8A=A4=E6=80=A7=E5=92=8C=E5=8F=AF=E6=89=A9=E5=B1=95?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 LLMParameters 类以统一处理 LLM 参数 - 新增 TokensUsage 类以统一处理 token 使用信息 - 更新 LLMResponse 和 LLMRequest 类,使用新的 LLMParameters 和 TokensUsage 类 - 优化数据模型结构,提高代码的可读性和可维护性 --- schema/__init__.py | 2 +- schema/dataset_generation.py | 40 ++++++++++++++++-------------------- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/schema/__init__.py b/schema/__init__.py index e8119ee..bb4c465 100644 --- a/schema/__init__.py +++ b/schema/__init__.py @@ -1,4 +1,4 @@ from .dataset import * -from .dataset_generation import APIProvider, LLMResponse, LLMRequest +from .dataset_generation import * from .md_doc import MarkdownNode from .prompt import promptTempleta \ No newline at end of file diff --git a/schema/dataset_generation.py b/schema/dataset_generation.py index 3a43423..2940f2d 100644 --- a/schema/dataset_generation.py +++ b/schema/dataset_generation.py @@ -12,40 +12,36 @@ class APIProvider(SQLModel, table=True): description="记录创建时间" ) +class LLMParameters(SQLModel): + temperature: Optional[float] = None + max_tokens: Optional[int] = None + top_p: Optional[float] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + seed: Optional[int] = None + +class TokensUsage(SQLModel): + prompt_tokens: int = Field(default=0, description="提示词使用的token数量") + completion_tokens: int = Field(default=0, description="完成部分使用的token数量") + prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="缓存命中token数量") + prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="缓存未命中token数量") + class LLMResponse(SQLModel): timestamp: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), description="响应的时间戳" ) response_id: str = Field(..., description="响应的唯一ID") - tokens_usage: dict = Field(default_factory=lambda: { - "prompt_tokens": 0, - "completion_tokens": 0, - "prompt_cache_hit_tokens": None, - "prompt_cache_miss_tokens": None - }, description="token使用信息") + tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息") response_content: dict = Field(default_factory=dict, description="API响应的内容") total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") - llm_parameters: dict = Field(default_factory=lambda: { - "temperature": None, - "max_tokens": None, - "top_p": None, - "frequency_penalty": None, - "presence_penalty": None, - "seed": None - }, description="API的生成参数") + llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数") class LLMRequest(SQLModel): prompt: str = Field(..., description="发送给API的提示词") - provider_id: int = Field(foreign_key="apiprovider.id") - provider: APIProvider = Relationship() + api_provider: APIProvider = Field(..., description="API提供者的信息") format: Optional[str] = Field(default=None, description="API响应的格式") response: list[LLMResponse] = Field(default_factory=list, description="API响应列表") error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息") total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒") - total_tokens_usage: dict = Field(default_factory=lambda: { - "prompt_tokens": 0, - "completion_tokens": 0, - "prompt_cache_hit_tokens": None, - "prompt_cache_miss_tokens": None - }, description="token使用信息") \ No newline at end of file + total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")