From 5fc90903fb951fb9907eaddce6376a6f47b0d8cd Mon Sep 17 00:00:00 2001
From: carry
Date: Sat, 19 Apr 2025 16:53:48 +0800
Subject: [PATCH] feat(tools): add reasoning.py tool module
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add reasoning.py, implementing interaction with the OpenAI API
- Add a call_openai_api function that sends requests and handles responses
- Support an optional LLMParameters argument to customize requests
- Record token usage from the API response
- Provide error handling and handle cache token fields that may be absent
---
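A minimal usage sketch for review purposes (this notes section is dropped by
`git am`): it assumes the project's schema module lets APIProvider be
constructed directly with api_key, base_url, and model_id; the provider values
below are placeholders, not values taken from this patch.

import asyncio
from schema import APIProvider, LLMRequest, LLMParameters
from tools.reasoning import call_openai_api

# Hypothetical provider for illustration; the __main__ block in the patch
# loads a real one from the project's database instead.
provider = APIProvider(
    api_key="sk-placeholder",
    base_url="https://api.example.com/v1",
    model_id="deepseek-chat",
)

request = LLMRequest(prompt="Hello, world!", api_provider=provider)
params = LLMParameters(temperature=0.7, max_tokens=100)  # optional; may be omitted

# call_openai_api is a coroutine, so drive it with asyncio.run().
response = asyncio.run(call_openai_api(request, params))
print(response.response_content, response.tokens_usage)

Note that call_openai_api does not raise on failure: it returns an LLMResponse
with response_id="error" and the message under the "error" key of
response_content, so callers should check for that key rather than catch
exceptions.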
 tools/reasoning.py | 82 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)
 create mode 100644 tools/reasoning.py

diff --git a/tools/reasoning.py b/tools/reasoning.py
new file mode 100644
index 0000000..5b63291
--- /dev/null
+++ b/tools/reasoning.py
@@ -0,0 +1,82 @@
+import sys
+import asyncio
+import openai
+from pathlib import Path
+from datetime import datetime, timezone
+from typing import Optional
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from schema import APIProvider, LLMRequest, LLMResponse, TokensUsage, LLMParameters
+
+async def call_openai_api(llm_request: LLMRequest, llm_parameters: Optional[LLMParameters] = None) -> LLMResponse:
+    start_time = datetime.now(timezone.utc)
+    client = openai.AsyncOpenAI(
+        api_key=llm_request.api_provider.api_key,
+        base_url=llm_request.api_provider.base_url
+    )
+
+    try:
+        messages = [{"role": "user", "content": llm_request.prompt}]
+        response = await client.chat.completions.create(
+            model=llm_request.api_provider.model_id,
+            messages=messages,
+            temperature=llm_parameters.temperature if llm_parameters else None,
+            max_tokens=llm_parameters.max_tokens if llm_parameters else None,
+            top_p=llm_parameters.top_p if llm_parameters else None,
+            frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
+            presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
+            seed=llm_parameters.seed if llm_parameters else None
+        )
+
+        end_time = datetime.now(timezone.utc)
+        duration = (end_time - start_time).total_seconds()
+
+        # Handle cache token fields that may not be present in the response
+        usage = response.usage
+        cache_hit = getattr(usage, 'prompt_cache_hit_tokens', None)
+        cache_miss = getattr(usage, 'prompt_cache_miss_tokens', None)
+
+        tokens_usage = TokensUsage(
+            prompt_tokens=usage.prompt_tokens,
+            completion_tokens=usage.completion_tokens,
+            prompt_cache_hit_tokens=cache_hit,
+            prompt_cache_miss_tokens=cache_miss if cache_miss is not None else usage.prompt_tokens
+        )
+
+        return LLMResponse(
+            response_id=response.id,
+            tokens_usage=tokens_usage,
+            response_content={"content": response.choices[0].message.content},
+            total_duration=duration,
+            llm_parameters=llm_parameters
+        )
+    except Exception as e:
+        end_time = datetime.now(timezone.utc)
+        duration = (end_time - start_time).total_seconds()
+        return LLMResponse(
+            response_id="error",
+            response_content={"error": str(e)},
+            total_duration=duration
+        )
+
+if __name__ == "__main__":
+    from sqlmodel import Session, select
+    from global_var import get_sql_engine, init_global_var
+    init_global_var("workdir")
+    api_state = "1 deepseek-chat"
+    with Session(get_sql_engine()) as session:
+        api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
+    llm_request = LLMRequest(
+        prompt="Hello, world!",
+        api_provider=api_provider
+    )
+    # Call without LLM parameters
+    result = asyncio.run(call_openai_api(llm_request))
+    print(f"\nResult without LLM parameters: {result}")
+
+    # Call with LLM parameters
+    params = LLMParameters(
+        temperature=0.7,
+        max_tokens=100
+    )
+    result = asyncio.run(call_openai_api(llm_request, params))
+    print(f"\nOpenAI API response: {result}")
\ No newline at end of file