feat(tools): add multi-round OpenAI API calls

- Add a rounds parameter to call_openai_api to support repeated calls
- Accumulate duration and token usage across the calls
- Store the result of each call in the LLMRequest object's response list
- Change the return type so the function returns an LLMRequest carrying the per-call information
- Improve error handling by recording the error message of each round
carry 2025-04-19 17:02:00 +08:00
parent 5fc90903fb
commit 90fde639ff
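
To make the new calling convention concrete before the diff, here is a minimal usage sketch. The import path, the provider values, and the direct construction of APIProvider are assumptions for illustration only (the repository's own __main__ example below loads the provider from a database via sqlmodel); the field and parameter names follow how the diff uses them.

import asyncio

from schema import APIProvider, LLMRequest, LLMParameters
from tools.openai_api import call_openai_api  # hypothetical import path; adjust to the actual module

# Hypothetical provider values; the real project reads APIProvider records from its database.
provider = APIProvider(
    api_key="sk-...",
    base_url="https://api.openai.com/v1",
    model_id="gpt-4o-mini",
)
request = LLMRequest(prompt="Hello, world!", api_provider=provider)

# Three rounds of the same prompt: each round appends an LLMResponse to
# request.response, while duration and token usage are accumulated on the request.
params = LLMParameters(temperature=0.7, max_tokens=100)
result = asyncio.run(call_openai_api(request, rounds=3, llm_parameters=params))

print(f"total duration: {result.total_duration:.2f}s")
print(f"total tokens: prompt={result.total_tokens_usage.prompt_tokens}, "
      f"completion={result.total_tokens_usage.completion_tokens}")
for n, resp in enumerate(result.response, 1):
    print(f"round {n}: {resp.response_content}")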


@@ -7,56 +7,81 @@ from typing import Optional
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 
 from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
 
-async def call_openai_api(llm_request: LLMRequest, llm_parameters: Optional[LLMParameters] = None) -> LLMResponse:
+async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
     start_time = datetime.now(timezone.utc)
     client = openai.AsyncOpenAI(
         api_key=llm_request.api_provider.api_key,
         base_url=llm_request.api_provider.base_url
     )
-    try:
-        messages = [{"role": "user", "content": llm_request.prompt}]
-        response = await client.chat.completions.create(
-            model=llm_request.api_provider.model_id,
-            messages=messages,
-            temperature=llm_parameters.temperature if llm_parameters else None,
-            max_tokens=llm_parameters.max_tokens if llm_parameters else None,
-            top_p=llm_parameters.top_p if llm_parameters else None,
-            frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
-            presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
-            seed=llm_parameters.seed if llm_parameters else None
-        )
-
-        end_time = datetime.now(timezone.utc)
-        duration = (end_time - start_time).total_seconds()
-
-        # 处理可能不存在的缓存token字段
-        usage = response.usage
-        cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
-        cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
-
-        tokens_usage = TokensUsage(
-            prompt_tokens=usage.prompt_tokens,
-            completion_tokens=usage.completion_tokens,
-            prompt_cache_hit_tokens=cache_hit,
-            prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
-        )
-
-        return LLMResponse(
-            response_id=response.id,
-            tokens_usage=tokens_usage,
-            response_content={"content": response.choices[0].message.content},
-            total_duration=duration,
-            llm_parameters=llm_parameters
-        )
-    except Exception as e:
-        end_time = datetime.now(timezone.utc)
-        duration = (end_time - start_time).total_seconds()
-        return LLMResponse(
-            response_id="error",
-            response_content={"error": str(e)},
-            total_duration=duration
-        )
+    total_duration = 0.0
+    total_tokens = TokensUsage()
+
+    for i in range(rounds):
+        try:
+            round_start = datetime.now(timezone.utc)
+            messages = [{"role": "user", "content": llm_request.prompt}]
+            response = await client.chat.completions.create(
+                model=llm_request.api_provider.model_id,
+                messages=messages,
+                temperature=llm_parameters.temperature if llm_parameters else None,
+                max_tokens=llm_parameters.max_tokens if llm_parameters else None,
+                top_p=llm_parameters.top_p if llm_parameters else None,
+                frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
+                presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
+                seed=llm_parameters.seed if llm_parameters else None
+            )
+
+            round_end = datetime.now(timezone.utc)
+            duration = (round_end - round_start).total_seconds()
+            total_duration += duration
+
+            # 处理可能不存在的缓存token字段
+            usage = response.usage
+            cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
+            cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
+
+            tokens_usage = TokensUsage(
+                prompt_tokens=usage.prompt_tokens,
+                completion_tokens=usage.completion_tokens,
+                prompt_cache_hit_tokens=cache_hit,
+                prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
+            )
+
+            # 累加总token使用量
+            total_tokens.prompt_tokens += tokens_usage.prompt_tokens
+            total_tokens.completion_tokens += tokens_usage.completion_tokens
+            if tokens_usage.prompt_cache_hit_tokens:
+                total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
+            if tokens_usage.prompt_cache_miss_tokens:
+                total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
+
+            llm_request.response.append(LLMResponse(
+                response_id=response.id,
+                tokens_usage=tokens_usage,
+                response_content={"content": response.choices[0].message.content},
+                total_duration=duration,
+                llm_parameters=llm_parameters
+            ))
+        except Exception as e:
+            round_end = datetime.now(timezone.utc)
+            duration = (round_end - round_start).total_seconds()
+            total_duration += duration
+            llm_request.response.append(LLMResponse(
+                response_id=f"error-round-{i+1}",
+                response_content={"error": str(e)},
+                total_duration=duration
+            ))
+            if llm_request.error is None:
+                llm_request.error = []
+            llm_request.error.append(str(e))
+
+    # 更新总耗时和总token使用量
+    llm_request.total_duration = total_duration
+    llm_request.total_tokens_usage = total_tokens
+    return llm_request
 
 if __name__ == "__main__":
     from sqlmodel import Session, select
@@ -69,14 +94,17 @@ if __name__ == "__main__":
         prompt="你好,世界!",
         api_provider=api_provider
    )
 
-    # 不使用LLM参数调用
-    result = asyncio.run(call_openai_api(llm_request))
-    print(f"\n不使用LLM参数调用结果: {result}")
-    # 使用LLM参数调用
-    params = LLMParameters(
-        temperature=0.7,
-        max_tokens=100
-    )
-    result = asyncio.run(call_openai_api(llm_request, params))
-    print(f"\nOpenAI API 响应: {result}")
+    # # 单次调用示例
+    # result = asyncio.run(call_openai_api(llm_request))
+    # print(f"\n单次调用结果 - 响应数量: {len(result.response)}")
+    # for i, resp in enumerate(result.response, 1):
+    #     print(f"响应{i}: {resp.response_content}")
+
+    # 多次调用示例
+    params = LLMParameters(temperature=0.7, max_tokens=100)
+    result = asyncio.run(call_openai_api(llm_request, 3, params))
+    print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
+    print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
+    for i, resp in enumerate(result.response, 1):
+        print(f"响应{i}: {resp.response_content}")