diff --git a/tools/reasoning.py b/tools/reasoning.py
index a50a69f..27cc718 100644
--- a/tools/reasoning.py
+++ b/tools/reasoning.py
@@ -16,21 +16,30 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
     total_duration = 0.0
     total_tokens = TokensUsage()
-
+    prompt = llm_request.prompt
+    if llm_request.format:
+        prompt += "\nPlease return the result in JSON format:\n" + llm_request.format
+
     for i in range(rounds):
+        round_start = datetime.now(timezone.utc)
         try:
-            round_start = datetime.now(timezone.utc)
-            messages = [{"role": "user", "content": llm_request.prompt}]
-            response = await client.chat.completions.create(
-                model=llm_request.api_provider.model_id,
-                messages=messages,
-                temperature=llm_parameters.temperature if llm_parameters else None,
-                max_tokens=llm_parameters.max_tokens if llm_parameters else None,
-                top_p=llm_parameters.top_p if llm_parameters else None,
-                frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
-                presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
-                seed=llm_parameters.seed if llm_parameters else None
-            )
+            messages = [{"role": "user", "content": prompt}]
+            create_args = {
+                "model": llm_request.api_provider.model_id,
+                "messages": messages,
+                "temperature": llm_parameters.temperature if llm_parameters else None,
+                "max_tokens": llm_parameters.max_tokens if llm_parameters else None,
+                "top_p": llm_parameters.top_p if llm_parameters else None,
+                "frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
+                "presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
+                "seed": llm_parameters.seed if llm_parameters else None
+            }
+
+            # Handle the format parameter
+            if llm_request.format:
+                create_args["response_format"] = {"type": "json_object"}
+
+            response = await client.chat.completions.create(**create_args)
             round_end = datetime.now(timezone.utc)
             duration = (round_end - round_start).total_seconds()
@@ -84,15 +93,19 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
     return llm_request
 
 if __name__ == "__main__":
+    from json_example import generate_example_json
     from sqlmodel import Session, select
     from global_var import get_sql_engine, init_global_var
+    from schema import dataset_item
+
     init_global_var("workdir")
     api_state = "1 deepseek-chat"
     with Session(get_sql_engine()) as session:
         api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
         llm_request = LLMRequest(
-            prompt="Hello, world!",
-            api_provider=api_provider
+            prompt="Test: say whatever you like",
+            api_provider=api_provider,
+            format=generate_example_json(dataset_item)
         )
 
         # # Single-call example
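
Note on the response_format change: below is a minimal standalone sketch of the JSON-mode pattern this patch introduces, assuming an OpenAI-compatible endpoint that supports {"type": "json_object"}; the base_url, api_key, model, and prompt are placeholders, not values from this repo. JSON mode also requires the word "JSON" to appear somewhere in the messages, which is why the patch appends the format instruction to the prompt text.

    import asyncio
    import json

    from openai import AsyncOpenAI

    async def main() -> None:
        # Placeholder endpoint and key; any OpenAI-compatible provider works.
        client = AsyncOpenAI(base_url="https://api.deepseek.com", api_key="sk-...")
        create_args = {
            "model": "deepseek-chat",
            "messages": [{"role": "user", "content": 'Return {"greeting": "..."} as JSON'}],
        }
        # Same conditional shape as the patch: only request JSON mode when a
        # structured response is actually wanted.
        create_args["response_format"] = {"type": "json_object"}
        response = await client.chat.completions.create(**create_args)
        # In JSON mode the returned content should parse as valid JSON.
        print(json.loads(response.choices[0].message.content))

    asyncio.run(main())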
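
Note on the __main__ harness: it now builds the request with format=generate_example_json(dataset_item), but that helper is not part of this diff. A hypothetical sketch of what such a helper could look like, assuming dataset_item is a Pydantic/SQLModel class; the field handling below is an assumption, not this repo's implementation:

    import json

    from pydantic import BaseModel

    def generate_example_json(model_cls: type[BaseModel]) -> str:
        # Hypothetical: map each field name to a type placeholder so the model
        # sees the expected JSON shape, e.g. {"question": "<str>", ...}.
        placeholders = {
            name: f"<{getattr(field.annotation, '__name__', field.annotation)}>"
            for name, field in model_cls.model_fields.items()
        }
        return json.dumps(placeholders, ensure_ascii=False)

    class DatasetItem(BaseModel):  # stand-in for schema.dataset_item
        question: str
        answer: str

    print(generate_example_json(DatasetItem))  # {"question": "<str>", "answer": "<str>"}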