Compare commits: e16882953d ... 90fde639ff (4 commits)

90fde639ff
5fc90903fb
81c2ad4a2d
314434951d
@@ -98,6 +98,11 @@ def dataset_generate_page():
         return selected_prompt, dataframe_value
 
     def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
+        doc = [i for i in get_docs() if i.name == doc_state][0]
+        prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
+        with Session(get_sql_engine()) as session:
+            api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
+
         variables_dict = {}
         # iterate over the rows of the DataFrame correctly
         for _, row in variables_dataframe.iterrows():
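
The added lookups assume the dropdown state strings carry the database id as a leading token, e.g. "1 deepseek-chat" as used in the tools/reasoning.py demo below. A minimal sketch of that convention:

api_state = "1 deepseek-chat"                # "<id> <model name>" dropdown value
provider_id = int(api_state.split(" ")[0])   # -> 1, the id used for the SQL lookup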
@@ -112,9 +117,6 @@ def dataset_generate_page():
             # simulate the workload of each step
             time.sleep(0.5)
 
-            # update the progress bar
-            # the first argument is the current progress ratio (0.0 to 1.0)
-            # the desc argument dynamically updates the text shown next to the progress bar
             current_progress = (i + 1) / total_steps
             progress(current_progress, desc=f"Processing step {i + 1}/{total_steps}")
 
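
The three removed comments documented Gradio's progress API. For reference, a minimal self-contained sketch of the same pattern (the widget wiring and step count here are hypothetical, not from this repo):

import time
import gradio as gr

def long_task(progress=gr.Progress()):
    total_steps = 10                      # hypothetical step count
    for i in range(total_steps):
        time.sleep(0.5)                   # simulate per-step work
        # first argument: progress ratio (0.0 to 1.0); desc: label next to the bar
        progress((i + 1) / total_steps, desc=f"Processing step {i + 1}/{total_steps}")
    return "done"

with gr.Blocks() as demo:
    out = gr.Textbox()
    gr.Button("Run").click(long_task, outputs=out)  # Gradio injects the tracker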
@@ -1,4 +1,4 @@
 from .dataset import *
-from .dataset_generation import APIProvider, LLMResponse, LLMRequest
+from .dataset_generation import *
 from .md_doc import MarkdownNode
 from .prompt import promptTempleta
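
The switch to a wildcard import presumably keeps the package re-exports in sync with the new models without listing every name, which is what lets tools/reasoning.py below write:

from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters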
@@ -12,40 +12,36 @@ class APIProvider(SQLModel, table=True):
         description="record creation time"
     )
 
+class LLMParameters(SQLModel):
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+    top_p: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    seed: Optional[int] = None
+
+class TokensUsage(SQLModel):
+    prompt_tokens: int = Field(default=0, description="number of tokens used by the prompt")
+    completion_tokens: int = Field(default=0, description="number of tokens used by the completion")
+    prompt_cache_hit_tokens: Optional[int] = Field(default=None, description="number of cache-hit tokens")
+    prompt_cache_miss_tokens: Optional[int] = Field(default=None, description="number of cache-miss tokens")
+
 class LLMResponse(SQLModel):
     timestamp: datetime = Field(
         default_factory=lambda: datetime.now(timezone.utc),
         description="timestamp of the response"
     )
     response_id: str = Field(..., description="unique ID of the response")
-    tokens_usage: dict = Field(default_factory=lambda: {
-        "prompt_tokens": 0,
-        "completion_tokens": 0,
-        "prompt_cache_hit_tokens": None,
-        "prompt_cache_miss_tokens": None
-    }, description="token usage info")
+    tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token usage info")
     response_content: dict = Field(default_factory=dict, description="content of the API response")
     total_duration: float = Field(default=0.0, description="total duration of the request, in seconds")
-    llm_parameters: dict = Field(default_factory=lambda: {
-        "temperature": None,
-        "max_tokens": None,
-        "top_p": None,
-        "frequency_penalty": None,
-        "presence_penalty": None,
-        "seed": None
-    }, description="generation parameters for the API")
+    llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM parameters")
 
 class LLMRequest(SQLModel):
     prompt: str = Field(..., description="prompt sent to the API")
-    provider_id: int = Field(foreign_key="apiprovider.id")
-    provider: APIProvider = Relationship()
+    api_provider: APIProvider = Field(..., description="information about the API provider")
     format: Optional[str] = Field(default=None, description="format of the API response")
     response: list[LLMResponse] = Field(default_factory=list, description="list of API responses")
     error: Optional[list[str]] = Field(default=None, description="errors raised during the API request")
     total_duration: float = Field(default=0.0, description="total duration of the request, in seconds")
-    total_tokens_usage: dict = Field(default_factory=lambda: {
-        "prompt_tokens": 0,
-        "completion_tokens": 0,
-        "prompt_cache_hit_tokens": None,
-        "prompt_cache_miss_tokens": None
-    }, description="token usage info")
+    total_tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token usage info")
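
A short sketch of what the typed models change in practice: token usage and generation parameters are now validated SQLModel objects rather than loose dicts, so fields are checked on construction and read by attribute instead of by key (the field values below are hypothetical):

usage = TokensUsage(prompt_tokens=12, completion_tokens=34)
resp = LLMResponse(
    response_id="resp-001",
    tokens_usage=usage,
    llm_parameters=LLMParameters(temperature=0.7),
)
print(resp.tokens_usage.completion_tokens)   # 34; previously resp.tokens_usage["completion_tokens"]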
tools/reasoning.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+import sys
+import asyncio
+import openai
+from pathlib import Path
+from datetime import datetime, timezone
+from typing import Optional
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
+
+async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
+    start_time = datetime.now(timezone.utc)
+    client = openai.AsyncOpenAI(
+        api_key=llm_request.api_provider.api_key,
+        base_url=llm_request.api_provider.base_url
+    )
+
+    total_duration = 0.0
+    total_tokens = TokensUsage()
+
+    for i in range(rounds):
+        try:
+            round_start = datetime.now(timezone.utc)
+            messages = [{"role": "user", "content": llm_request.prompt}]
+            response = await client.chat.completions.create(
+                model=llm_request.api_provider.model_id,
+                messages=messages,
+                temperature=llm_parameters.temperature if llm_parameters else None,
+                max_tokens=llm_parameters.max_tokens if llm_parameters else None,
+                top_p=llm_parameters.top_p if llm_parameters else None,
+                frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
+                presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
+                seed=llm_parameters.seed if llm_parameters else None
+            )
+
+            round_end = datetime.now(timezone.utc)
+            duration = (round_end - round_start).total_seconds()
+            total_duration += duration
+
+            # handle cache-token fields that may be absent from the response
+            usage = response.usage
+            cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
+            cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
+
+            tokens_usage = TokensUsage(
+                prompt_tokens=usage.prompt_tokens,
+                completion_tokens=usage.completion_tokens,
+                prompt_cache_hit_tokens=cache_hit,
+                prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
+            )
+
+            # accumulate the total token usage
+            total_tokens.prompt_tokens += tokens_usage.prompt_tokens
+            total_tokens.completion_tokens += tokens_usage.completion_tokens
+            if tokens_usage.prompt_cache_hit_tokens:
+                total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
+            if tokens_usage.prompt_cache_miss_tokens:
+                total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
+
+            llm_request.response.append(LLMResponse(
+                response_id=response.id,
+                tokens_usage=tokens_usage,
+                response_content={"content": response.choices[0].message.content},
+                total_duration=duration,
+                llm_parameters=llm_parameters
+            ))
+        except Exception as e:
+            round_end = datetime.now(timezone.utc)
+            duration = (round_end - round_start).total_seconds()
+            total_duration += duration
+
+            llm_request.response.append(LLMResponse(
+                response_id=f"error-round-{i+1}",
+                response_content={"error": str(e)},
+                total_duration=duration
+            ))
+            if llm_request.error is None:
+                llm_request.error = []
+            llm_request.error.append(str(e))
+
+    # update the total duration and total token usage
+    llm_request.total_duration = total_duration
+    llm_request.total_tokens_usage = total_tokens
+
+    return llm_request
+
+if __name__ == "__main__":
+    from sqlmodel import Session, select
+    from global_var import get_sql_engine, init_global_var
+    init_global_var("workdir")
+    api_state = "1 deepseek-chat"
+    with Session(get_sql_engine()) as session:
+        api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
+    llm_request = LLMRequest(
+        prompt="Hello, world!",
+        api_provider=api_provider
+    )
+
+    # # single-call example
+    # result = asyncio.run(call_openai_api(llm_request))
+    # print(f"\nSingle call result - number of responses: {len(result.response)}")
+    # for i, resp in enumerate(result.response, 1):
+    #     print(f"Response {i}: {resp.response_content}")
+
+    # multi-call example
+    params = LLMParameters(temperature=0.7, max_tokens=100)
+    result = asyncio.run(call_openai_api(llm_request, 3, params))
+    print(f"\nResult of 3 calls - total duration: {result.total_duration:.2f}s")
+    print(f"Total token usage: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
+    for i, resp in enumerate(result.response, 1):
+        print(f"Response {i}: {resp.response_content}")
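
Note that call_openai_api awaits its rounds sequentially. Because it is a coroutine, independent requests can still be fanned out concurrently; a sketch (not part of this commit, prompts are made up):

async def fan_out(provider: APIProvider) -> list[LLMRequest]:
    requests = [LLMRequest(prompt=p, api_provider=provider)
                for p in ("Hello!", "Summarize this repo.")]
    # run the requests concurrently; each one still loops through its own rounds
    return await asyncio.gather(*(call_openai_api(r) for r in requests))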