From 5a21c8598a8b5675233daf2ebd17ee02d49c60f9 Mon Sep 17 00:00:00 2001 From: carry Date: Sat, 19 Apr 2025 21:10:22 +0800 Subject: [PATCH] =?UTF-8?q?feat(tools):=20=E6=94=AF=E6=8C=81=20OpenAI=20AP?= =?UTF-8?q?I=20=E7=9A=84=20JSON=20=E6=A0=BC=E5=BC=8F=E8=BF=94=E5=9B=9E?= =?UTF-8?q?=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 call_openai_api 函数中添加对 JSON 格式返回结果的支持 - 增加 llm_request.format 参数处理,将用户 prompt 与格式要求合并 - 添加 response_format 参数到 OpenAI API 请求 - 更新示例,使用 JSON 格式返回结果 --- tools/reasoning.py | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/tools/reasoning.py b/tools/reasoning.py index a50a69f..27cc718 100644 --- a/tools/reasoning.py +++ b/tools/reasoning.py @@ -16,21 +16,30 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete total_duration = 0.0 total_tokens = TokensUsage() - + prompt = llm_request.prompt + round_start = datetime.now(timezone.utc) + if llm_request.format: + prompt += "\n请以JSON格式返回结果" + llm_request.format + for i in range(rounds): + round_start = datetime.now(timezone.utc) try: - round_start = datetime.now(timezone.utc) - messages = [{"role": "user", "content": llm_request.prompt}] - response = await client.chat.completions.create( - model=llm_request.api_provider.model_id, - messages=messages, - temperature=llm_parameters.temperature if llm_parameters else None, - max_tokens=llm_parameters.max_tokens if llm_parameters else None, - top_p=llm_parameters.top_p if llm_parameters else None, - frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None, - presence_penalty=llm_parameters.presence_penalty if llm_parameters else None, - seed=llm_parameters.seed if llm_parameters else None - ) + messages = [{"role": "user", "content": prompt}] + create_args = { + "model": llm_request.api_provider.model_id, + "messages": messages, + "temperature": llm_parameters.temperature if llm_parameters else None, + "max_tokens": llm_parameters.max_tokens if llm_parameters else None, + "top_p": llm_parameters.top_p if llm_parameters else None, + "frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None, + "presence_penalty": llm_parameters.presence_penalty if llm_parameters else None, + "seed": llm_parameters.seed if llm_parameters else None + } # 处理format参数 + + if llm_request.format: + create_args["response_format"] = {"type": "json_object"} + + response = await client.chat.completions.create(**create_args) round_end = datetime.now(timezone.utc) duration = (round_end - round_start).total_seconds() @@ -84,15 +93,19 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete return llm_request if __name__ == "__main__": + from json_example import generate_example_json from sqlmodel import Session, select from global_var import get_sql_engine, init_global_var + from schema import dataset_item + init_global_var("workdir") api_state = "1 deepseek-chat" with Session(get_sql_engine()) as session: api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first() llm_request = LLMRequest( - prompt="你好,世界!", - api_provider=api_provider + prompt="测试,随便说点什么", + api_provider=api_provider, + format=generate_example_json(dataset_item) ) # # 单次调用示例