import sys
import asyncio
import openai
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional

sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters


async def call_openai_api(
    llm_request: LLMRequest,
    rounds: int = 1,
    llm_parameters: Optional[LLMParameters] = None,
) -> LLMRequest:
    """Call an OpenAI-compatible chat completions endpoint `rounds` times,
    appending each result (or error) to `llm_request.response` and
    accumulating total duration and token usage on the request."""
    client = openai.AsyncOpenAI(
        api_key=llm_request.api_provider.api_key,
        base_url=llm_request.api_provider.base_url
    )
    total_duration = 0.0
    total_tokens = TokensUsage()

    prompt = llm_request.prompt
    if llm_request.format:
        prompt += "\nPlease return the result in JSON format:" + llm_request.format

    for i in range(rounds):
        round_start = datetime.now(timezone.utc)
        try:
            messages = [{"role": "user", "content": prompt}]
            create_args = {
                "model": llm_request.api_provider.model_id,
                "messages": messages,
                "temperature": llm_parameters.temperature if llm_parameters else None,
                "max_tokens": llm_parameters.max_tokens if llm_parameters else None,
                "top_p": llm_parameters.top_p if llm_parameters else None,
                "frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
                "presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
                "seed": llm_parameters.seed if llm_parameters else None,
            }
            # Drop unset parameters so explicit nulls are not sent to the API.
            create_args = {k: v for k, v in create_args.items() if v is not None}

            # Handle the format parameter: request structured JSON output.
            if llm_request.format:
                create_args["response_format"] = {"type": "json_object"}

            response = await client.chat.completions.create(**create_args)

            round_end = datetime.now(timezone.utc)
            duration = (round_end - round_start).total_seconds()
            total_duration += duration

            # Cache token fields are provider-specific and may be absent.
            usage = response.usage
            cache_hit = getattr(usage, "prompt_cache_hit_tokens", None)
            cache_miss = getattr(usage, "prompt_cache_miss_tokens", None)

            tokens_usage = TokensUsage(
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                prompt_cache_hit_tokens=cache_hit,
                prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
            )

            # Accumulate total token usage across rounds.
            total_tokens.prompt_tokens += tokens_usage.prompt_tokens
            total_tokens.completion_tokens += tokens_usage.completion_tokens
            if tokens_usage.prompt_cache_hit_tokens:
                total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
            if tokens_usage.prompt_cache_miss_tokens:
                total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens

            llm_request.response.append(LLMResponse(
                response_id=response.id,
                tokens_usage=tokens_usage,
                content=response.choices[0].message.content,
                total_duration=duration,
                llm_parameters=llm_parameters
            ))
        except Exception as e:
            round_end = datetime.now(timezone.utc)
            duration = (round_end - round_start).total_seconds()
            total_duration += duration
            # Record the failed round so callers still see partial results.
            llm_request.response.append(LLMResponse(
                response_id=f"error-round-{i+1}",
                content={"error": str(e)},
                total_duration=duration
            ))
            if llm_request.error is None:
                llm_request.error = []
            llm_request.error.append(str(e))

    # Update the aggregate duration and token usage on the request.
    llm_request.total_duration = total_duration
    llm_request.total_tokens_usage = total_tokens
    return llm_request


if __name__ == "__main__":
    from json_example import generate_json_example
    from sqlmodel import Session, select
    from global_var import get_sql_engine, init_global_var
    from schema import DatasetItem

    init_global_var("workdir")
    api_state = "1 deepseek-chat"
    with Session(get_sql_engine()) as session:
        api_provider = session.exec(
            select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))
        ).first()
        llm_request = LLMRequest(
            prompt="Test, just say anything",
            api_provider=api_provider,
            format=generate_json_example(DatasetItem)
        )

        # # Single-call example
        # result = asyncio.run(call_openai_api(llm_request))
        # print(f"\nSingle call - number of responses: {len(result.response)}")
        # for i, resp in enumerate(result.response, 1):
        #     print(f"Response {i}: {resp.content}")

        # Multi-round call example
        params = LLMParameters(temperature=0.7, max_tokens=100)
        result = asyncio.run(call_openai_api(llm_request, 3, params))
        print(f"\nResults of 3 calls - total duration: {result.total_duration:.2f}s")
        print(f"Total token usage: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
        for i, resp in enumerate(result.response, 1):
            print(f"Response {i}: {resp.content}")
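
# Note on assumptions: the exact model definitions live in schema.py; this
# module relies on LLMRequest exposing `prompt`, `api_provider`, `format`, a
# `response` list, `error`, `total_duration`, and `total_tokens_usage`; on
# APIProvider exposing `api_key`, `base_url`, and `model_id`; and on
# TokensUsage initialising its counters to 0 so the += accumulation above is
# safe on a fresh instance.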