diff --git a/tools/reasoning.py b/tools/reasoning.py
index 5b63291..a50a69f 100644
--- a/tools/reasoning.py
+++ b/tools/reasoning.py
@@ -7,56 +7,81 @@ from typing import Optional
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
 
-async def call_openai_api(llm_request: LLMRequest, llm_parameters: Optional[LLMParameters] = None) -> LLMResponse:
+async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_parameters: Optional[LLMParameters] = None) -> LLMRequest:
     start_time = datetime.now(timezone.utc)
     client = openai.AsyncOpenAI(
         api_key=llm_request.api_provider.api_key,
         base_url=llm_request.api_provider.base_url
     )
 
-    try:
-        messages = [{"role": "user", "content": llm_request.prompt}]
-        response = await client.chat.completions.create(
-            model=llm_request.api_provider.model_id,
-            messages=messages,
-            temperature=llm_parameters.temperature if llm_parameters else None,
-            max_tokens=llm_parameters.max_tokens if llm_parameters else None,
-            top_p=llm_parameters.top_p if llm_parameters else None,
-            frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
-            presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
-            seed=llm_parameters.seed if llm_parameters else None
-        )
-
-        end_time = datetime.now(timezone.utc)
-        duration = (end_time - start_time).total_seconds()
-
-        # Handle cache token fields that may not be present
-        usage = response.usage
-        cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
-        cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
-
-        tokens_usage = TokensUsage(
-            prompt_tokens=usage.prompt_tokens,
-            completion_tokens=usage.completion_tokens,
-            prompt_cache_hit_tokens=cache_hit,
-            prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
-        )
-
-        return LLMResponse(
-            response_id=response.id,
-            tokens_usage=tokens_usage,
-            response_content={"content": response.choices[0].message.content},
-            total_duration=duration,
-            llm_parameters=llm_parameters
-        )
-    except Exception as e:
-        end_time = datetime.now(timezone.utc)
-        duration = (end_time - start_time).total_seconds()
-        return LLMResponse(
-            response_id="error",
-            response_content={"error": str(e)},
-            total_duration=duration
-        )
+    total_duration = 0.0
+    total_tokens = TokensUsage()
+
+    for i in range(rounds):
+        try:
+            round_start = datetime.now(timezone.utc)
+            messages = [{"role": "user", "content": llm_request.prompt}]
+            response = await client.chat.completions.create(
+                model=llm_request.api_provider.model_id,
+                messages=messages,
+                temperature=llm_parameters.temperature if llm_parameters else None,
+                max_tokens=llm_parameters.max_tokens if llm_parameters else None,
+                top_p=llm_parameters.top_p if llm_parameters else None,
+                frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
+                presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
+                seed=llm_parameters.seed if llm_parameters else None
+            )
+
+            round_end = datetime.now(timezone.utc)
+            duration = (round_end - round_start).total_seconds()
+            total_duration += duration
+
+            # Handle cache token fields that may not be present
+            usage = response.usage
+            cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
+            cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
+
+            tokens_usage = TokensUsage(
+                prompt_tokens=usage.prompt_tokens,
+                completion_tokens=usage.completion_tokens,
+                prompt_cache_hit_tokens=cache_hit,
+                prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
+            )
+
+            # Accumulate the total token usage
+            total_tokens.prompt_tokens += tokens_usage.prompt_tokens
+            total_tokens.completion_tokens += tokens_usage.completion_tokens
+            if tokens_usage.prompt_cache_hit_tokens:
+                total_tokens.prompt_cache_hit_tokens = (total_tokens.prompt_cache_hit_tokens or 0) + tokens_usage.prompt_cache_hit_tokens
+            if tokens_usage.prompt_cache_miss_tokens:
+                total_tokens.prompt_cache_miss_tokens = (total_tokens.prompt_cache_miss_tokens or 0) + tokens_usage.prompt_cache_miss_tokens
+
+            llm_request.response.append(LLMResponse(
+                response_id=response.id,
+                tokens_usage=tokens_usage,
+                response_content={"content": response.choices[0].message.content},
+                total_duration=duration,
+                llm_parameters=llm_parameters
+            ))
+        except Exception as e:
+            round_end = datetime.now(timezone.utc)
+            duration = (round_end - round_start).total_seconds()
+            total_duration += duration
+
+            llm_request.response.append(LLMResponse(
+                response_id=f"error-round-{i+1}",
+                response_content={"error": str(e)},
+                total_duration=duration
+            ))
+            if llm_request.error is None:
+                llm_request.error = []
+            llm_request.error.append(str(e))
+
+    # Update the total duration and total token usage
+    llm_request.total_duration = total_duration
+    llm_request.total_tokens_usage = total_tokens
+
+    return llm_request
 
 if __name__ == "__main__":
     from sqlmodel import Session, select
@@ -69,14 +94,17 @@ if __name__ == "__main__":
         prompt="Hello, world!",
         api_provider=api_provider
     )
 
-    # Call without LLM parameters
-    result = asyncio.run(call_openai_api(llm_request))
-    print(f"\nResult without LLM parameters: {result}")
-    # Call with LLM parameters
-    params = LLMParameters(
-        temperature=0.7,
-        max_tokens=100
-    )
-    result = asyncio.run(call_openai_api(llm_request, params))
-    print(f"\nOpenAI API response: {result}")
\ No newline at end of file
+    # # Single-call example
+    # result = asyncio.run(call_openai_api(llm_request))
+    # print(f"\nSingle-call result - number of responses: {len(result.response)}")
+    # for i, resp in enumerate(result.response, 1):
+    #     print(f"Response {i}: {resp.response_content}")
+
+    # Multi-round call example
+    params = LLMParameters(temperature=0.7, max_tokens=100)
+    result = asyncio.run(call_openai_api(llm_request, 3, params))
+    print(f"\nResults of 3 rounds - total duration: {result.total_duration:.2f}s")
+    print(f"Total token usage: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
+    for i, resp in enumerate(result.response, 1):
+        print(f"Response {i}: {resp.response_content}")
\ No newline at end of file