
- Rename the response_content field of the LLMResponse class to content
- Change the field type from dict to str to more accurately represent the response content
- Update the arguments passed to LLMResponse in reasoning.py accordingly
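
A minimal sketch of what the renamed field might look like on the model, assuming a Pydantic-style class in schema.py; only the rename to content and the str type come from this change, the other fields shown are inferred from how this script uses them:

from pydantic import BaseModel

class LLMResponse(BaseModel):
    response_id: str = ""
    # Renamed from response_content; type changed from dict to str (raw model output text)
    content: str = ""
    total_duration: float = 0.0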
import sys
import json
import asyncio
import openai
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional

sys.path.append(str(Path(__file__).resolve().parent.parent))

from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters

async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
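    """Call an OpenAI-compatible chat completions endpoint one or more times.

    Each round appends an LLMResponse to llm_request.response; a failed round is
    recorded as an error response instead of raising. Total duration and token
    usage are accumulated onto llm_request, which is returned.
    """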
    start_time = datetime.now(timezone.utc)
    client = openai.AsyncOpenAI(
        api_key=llm_request.api_provider.api_key,
        base_url=llm_request.api_provider.base_url
    )

    total_duration = 0.0
    total_tokens = TokensUsage()
    prompt = llm_request.prompt
    round_start = datetime.now(timezone.utc)
    if llm_request.format:
        prompt += "\n请以JSON格式返回结果" + llm_request.format

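    # Each round sends an independent request with the same prompt; a failure in
    # one round is recorded but does not abort the remaining rounds.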
    for i in range(rounds):
        round_start = datetime.now(timezone.utc)
        try:
            messages = [{"role": "user", "content": prompt}]
            create_args = {
                "model": llm_request.api_provider.model_id,
                "messages": messages,
                "temperature": llm_parameters.temperature if llm_parameters else None,
                "max_tokens": llm_parameters.max_tokens if llm_parameters else None,
                "top_p": llm_parameters.top_p if llm_parameters else None,
                "frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
                "presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
                "seed": llm_parameters.seed if llm_parameters else None
            }
            # Drop unset parameters so the provider's defaults apply instead of explicit nulls
            create_args = {k: v for k, v in create_args.items() if v is not None}

            # Handle the format parameter: request a JSON object response
            if llm_request.format:
                create_args["response_format"] = {"type": "json_object"}

            response = await client.chat.completions.create(**create_args)

            round_end = datetime.now(timezone.utc)
            duration = (round_end - round_start).total_seconds()
            total_duration += duration

            # Handle cache token fields that some providers do not return
            usage = response.usage
            cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
            cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)

            tokens_usage = TokensUsage(
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                prompt_cache_hit_tokens=cache_hit,
                prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
            )

            # Accumulate total token usage across rounds
            total_tokens.prompt_tokens += tokens_usage.prompt_tokens
            total_tokens.completion_tokens += tokens_usage.completion_tokens
            if tokens_usage.prompt_cache_hit_tokens:
                total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
            if tokens_usage.prompt_cache_miss_tokens:
                total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens

            llm_request.response.append(LLMResponse(
                response_id=response.id,
                tokens_usage=tokens_usage,
                content=response.choices[0].message.content,
                total_duration=duration,
                llm_parameters=llm_parameters
            ))
        except Exception as e:
            round_end = datetime.now(timezone.utc)
            duration = (round_end - round_start).total_seconds()
            total_duration += duration

            # content is now a str, so serialize the error instead of passing a dict
            llm_request.response.append(LLMResponse(
                response_id=f"error-round-{i+1}",
                content=json.dumps({"error": str(e)}, ensure_ascii=False),
                total_duration=duration
            ))
            if llm_request.error is None:
                llm_request.error = []
            llm_request.error.append(str(e))

    # Update total duration and total token usage
    llm_request.total_duration = total_duration
    llm_request.total_tokens_usage = total_tokens

    return llm_request

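
# Manual smoke test: assumes the "workdir" globals are initialized and an
# APIProvider row (id 1, deepseek-chat here) already exists in the database.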
if __name__ == "__main__":
    from json_example import generate_json_example
    from sqlmodel import Session, select
    from global_var import get_sql_engine, init_global_var
    from schema import DatasetItem

    init_global_var("workdir")
    api_state = "1 deepseek-chat"
    with Session(get_sql_engine()) as session:
        api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
        llm_request = LLMRequest(
            prompt="测试,随便说点什么",
            api_provider=api_provider,
            format=generate_json_example(DatasetItem)
        )

        # Single-call example
        # result = asyncio.run(call_openai_api(llm_request))
        # print(f"\n单次调用结果 - 响应数量: {len(result.response)}")
        # for i, resp in enumerate(result.response, 1):
        #     print(f"响应{i}: {resp.content}")

        # Multi-call example
        params = LLMParameters(temperature=0.7, max_tokens=100)
        result = asyncio.run(call_openai_api(llm_request, 3, params))
        print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
        print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
        for i, resp in enumerate(result.response, 1):
            print(f"响应{i}: {resp.content}")