gzhu-biyesheji/tools/reasoning.py
carry 5fc3b4950b refactor(schema): 修改 LLMResponse 中 API 响应内容的字段名称
- 将 LLMResponse 类中的 response_content 字段重命名为 content
- 更新字段类型从 dict 改为 str,以更准确地表示响应内容
- 在 reasoning.py 中相应地修改了调用 LLMResponse 时的参数
2025-04-20 18:40:51 +08:00

123 lines
5.4 KiB
Python

import asyncio
import json
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import openai
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
    """Send the request's prompt to an OpenAI-compatible chat endpoint ``rounds`` times.

    Every round is an independent, single-message chat completion. One
    ``LLMResponse`` is appended to ``llm_request.response`` per round, whether
    the round succeeded or failed. A failed round is recorded with the id
    ``error-round-<n>`` and its message is also appended to
    ``llm_request.error``; the remaining rounds still run.

    Args:
        llm_request: Request to execute. Its ``api_provider`` supplies the API
            key, base URL and model id. If ``format`` is set, it is appended to
            the prompt and JSON output mode is requested.
        rounds: Number of completions to request.
        llm_parameters: Optional sampling parameters. Only the values that are
            not ``None`` are forwarded to the API.

    Returns:
        The same ``llm_request`` object, mutated in place, with
        ``total_duration`` and ``total_tokens_usage`` set to the sums over all
        rounds.
    """
    client = openai.AsyncOpenAI(
        api_key=llm_request.api_provider.api_key,
        base_url=llm_request.api_provider.base_url,
    )

    prompt = llm_request.prompt
    if llm_request.format:
        prompt += "\n请以JSON格式返回结果" + llm_request.format

    # The request payload is identical for every round, so build it once.
    create_args = {
        "model": llm_request.api_provider.model_id,
        "messages": [{"role": "user", "content": prompt}],
    }
    if llm_parameters is not None:
        # Forward only the parameters that were actually set, so that unset
        # ones are omitted from the request instead of being sent as null.
        for field in (
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "seed",
        ):
            value = getattr(llm_parameters, field)
            if value is not None:
                create_args[field] = value
    if llm_request.format:
        create_args["response_format"] = {"type": "json_object"}

    total_duration = 0.0
    total_tokens = TokensUsage()

    for i in range(rounds):
        round_start = datetime.now(timezone.utc)
        try:
            response = await client.chat.completions.create(**create_args)
            duration = (datetime.now(timezone.utc) - round_start).total_seconds()

            # ``usage`` may be missing, and the cache hit/miss counters are
            # not present on every provider's usage object, so read all of
            # them defensively.
            usage = response.usage
            prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
            completion_tokens = getattr(usage, "completion_tokens", 0) or 0
            cache_hit = getattr(usage, "prompt_cache_hit_tokens", None)
            cache_miss = getattr(usage, "prompt_cache_miss_tokens", None)
            tokens_usage = TokensUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                prompt_cache_hit_tokens=cache_hit,
                # Without cache information the whole prompt counts as a miss.
                prompt_cache_miss_tokens=cache_miss if cache_miss is not None else prompt_tokens,
            )
            round_response = LLMResponse(
                response_id=response.id,
                tokens_usage=tokens_usage,
                content=response.choices[0].message.content,
                total_duration=duration,
                llm_parameters=llm_parameters,
            )
        except Exception as e:
            # Deliberately broad: a failed round is recorded and must not
            # prevent the remaining rounds from running.
            duration = (datetime.now(timezone.utc) - round_start).total_seconds()
            round_response = LLMResponse(
                response_id=f"error-round-{i+1}",
                # ``content`` is a string field, so the error payload is
                # serialized to JSON text rather than passed as a dict.
                content=json.dumps({"error": str(e)}, ensure_ascii=False),
                total_duration=duration,
            )
            if llm_request.error is None:
                llm_request.error = []
            llm_request.error.append(str(e))
        else:
            # Accumulate totals only for rounds that were fully processed.
            total_tokens.prompt_tokens = (total_tokens.prompt_tokens or 0) + prompt_tokens
            total_tokens.completion_tokens = (total_tokens.completion_tokens or 0) + completion_tokens
            if tokens_usage.prompt_cache_hit_tokens:
                total_tokens.prompt_cache_hit_tokens = (
                    (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
                )
            if tokens_usage.prompt_cache_miss_tokens:
                total_tokens.prompt_cache_miss_tokens = (
                    (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
                )

        # Added exactly once per round, on both the success and error paths.
        total_duration += duration
        llm_request.response.append(round_response)

    # Publish the aggregated duration and token usage on the request.
    llm_request.total_duration = total_duration
    llm_request.total_tokens_usage = total_tokens
    return llm_request
if __name__ == "__main__":
    # Manual smoke test: load an API provider from the project database and
    # run a few completions against it, printing timing and token totals.
    from json_example import generate_json_example
    from sqlmodel import Session, select
    from global_var import get_sql_engine, init_global_var
    from schema import DatasetItem

    init_global_var("workdir")

    # Only the leading integer of this string is used, as the provider id.
    api_state = "1 deepseek-chat"
    provider_id = int(api_state.split(" ")[0])

    with Session(get_sql_engine()) as session:
        api_provider = session.exec(
            select(APIProvider).where(APIProvider.id == provider_id)
        ).first()
        # ``first()`` yields None when no row matches; stop with a clear
        # message instead of failing later on a None provider.
        if api_provider is None:
            raise SystemExit(f"APIProvider with id {provider_id} not found")

        llm_request = LLMRequest(
            prompt="测试,随便说点什么",
            api_provider=api_provider,
            format=generate_json_example(DatasetItem),
        )

        # Single-call example:
        # result = asyncio.run(call_openai_api(llm_request))
        # print(f"\n单次调用结果 - 响应数量: {len(result.response)}")
        # for i, resp in enumerate(result.response, 1):
        #     print(f"响应{i}: {resp.content}")

        # Multi-call example: three rounds with explicit sampling parameters.
        params = LLMParameters(temperature=0.7, max_tokens=100)
        result = asyncio.run(call_openai_api(llm_request, 3, params))
        print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
        print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
        for i, resp in enumerate(result.response, 1):
            print(f"响应{i}: {resp.content}")