Compare commits

..

6 Commits

Author SHA1 Message Date
carry
868fcd45ba refactor(project): 重构项目文件组织结构
- 修改模型管理和训练页面的导入路径
- 更新 main.py 中的导入模块
- 调整 tools 包的内容,移除 model 模块
- 新建 train 包,包含 model 模块
- 优化 __init__.py 文件,简化导入语句
2025-04-19 21:49:19 +08:00
carry
5a21c8598a feat(tools): 支持 OpenAI API 的 JSON 格式返回结果
- 在 call_openai_api 函数中添加对 JSON 格式返回结果的支持
- 增加 llm_request.format 参数处理,将用户 prompt 与格式要求合并
- 添加 response_format 参数到 OpenAI API 请求
- 更新示例,使用 JSON 格式返回结果
2025-04-19 21:10:22 +08:00
carry
1e829c9268 feat(tools): 优化 JSON 示例生成函数
- 增加 include_optional 参数,决定是否包含可选字段
- 添加 list_length 参数,用于控制列表字段的示例长度
- 在列表示例中添加省略标记,更直观展示多元素列表
- 优化字典字段的示例生成逻辑
2025-04-19 21:07:00 +08:00
carry
9fc3ab904b feat(frontend): 实现了固定参数的注入 2025-04-19 17:48:45 +08:00
carry
d827f9758f fix(frontend): 修复dataframe_value返回值只有一列的bug 2025-04-19 17:30:10 +08:00
carry
ff1e9731bc fix(tools): 修复call_openai_api的导出 2025-04-19 17:13:19 +08:00
9 changed files with 58 additions and 30 deletions

View File

@ -6,6 +6,7 @@ from sqlmodel import Session, select
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import APIProvider
from tools import call_openai_api
from global_var import get_docs, get_prompt_store, get_sql_engine
def dataset_generate_page():
@ -34,6 +35,7 @@ def dataset_generate_page():
prompt_content = prompt_data["content"]
prompt_template = PromptTemplate.from_template(prompt_content)
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
initial_dataframe_value = [[var, ""] for var in input_variables]
prompt_dropdown = gr.Dropdown(
@ -95,22 +97,30 @@ def dataset_generate_page():
input_variables = prompt_template.input_variables
input_variables.remove("document_slice")
dataframe_value = [] if input_variables is None else input_variables
dataframe_value = [[var, ""] for var in input_variables]
return selected_prompt, dataframe_value
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
doc = [i for i in get_docs() if i.name == doc_state][0]
doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
prompt = PromptTemplate.from_template(prompt["content"])
with Session(get_sql_engine()) as session:
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
variables_dict = {}
# 正确遍历DataFrame的行数据
for _, row in variables_dataframe.iterrows():
var_name = row['变量名'].strip()
var_value = row['变量值'].strip()
if var_name:
variables_dict[var_name] = var_value
# 注入除document_slice以外的所有参数
prompt = prompt.partial(**variables_dict)
print(doc)
print(prompt.format(document_slice="test"))
print(variables_dict)
import time
total_steps = rounds
for i in range(total_steps):

View File

@ -7,7 +7,7 @@ import torch
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from tools.model import get_model_name
from train import get_model_name
def model_manage_page():
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹

View File

@ -8,7 +8,8 @@ from transformers import TrainerCallback
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
from tools import train_model, find_available_port
from tools import find_available_port
from train import train_model
def train_page():
with gr.Blocks() as demo:

View File

@ -1,5 +1,5 @@
import gradio as gr
import unsloth
import train
from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store
from global_var import init_global_var, get_sql_engine, get_prompt_store

View File

@ -1,5 +1,5 @@
from .parse_markdown import parse_markdown
from .parse_markdown import *
from .document import *
from .json_example import generate_example_json
from .model import *
from .port import *
from .port import *
from .reasoning import call_openai_api

View File

@ -3,17 +3,19 @@ from typing import Any, Dict, List, Optional, Union, get_args, get_origin
import json
from datetime import datetime, date
def generate_example_json(model: type[BaseModel]) -> str:
def generate_example_json(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
"""
根据 Pydantic V2 模型生成示例 JSON 数据结构
"""
def _generate_example(field_type: Any) -> Any:
origin = get_origin(field_type)
args = get_args(field_type)
if origin is list or origin is List:
return [_generate_example(args[0])] if args else []
# 生成多个元素,这里生成 3 个
result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
result.append("......")
return result
elif origin is dict or origin is Dict:
if len(args) == 2:
return {"key": _generate_example(args[1])}
@ -35,21 +37,22 @@ def generate_example_json(model: type[BaseModel]) -> str:
elif field_type is date:
return date.today().isoformat()
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
return json.loads(generate_example_json(field_type))
return json.loads(generate_example_json(field_type, include_optional))
else:
# 处理直接类型注解(非泛型)
if field_type is type(None):
return None
try:
if issubclass(field_type, BaseModel):
return json.loads(generate_example_json(field_type))
return json.loads(generate_example_json(field_type, include_optional))
except TypeError:
pass
return "unknown"
example_data = {}
for field_name, field in model.model_fields.items():
example_data[field_name] = _generate_example(field.annotation)
if include_optional or not isinstance(field.default, type(None)):
example_data[field_name] = _generate_example(field.annotation)
return json.dumps(example_data, indent=2, default=str)

View File

@ -16,21 +16,30 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
total_duration = 0.0
total_tokens = TokensUsage()
prompt = llm_request.prompt
round_start = datetime.now(timezone.utc)
if llm_request.format:
prompt += "\n请以JSON格式返回结果" + llm_request.format
for i in range(rounds):
round_start = datetime.now(timezone.utc)
try:
round_start = datetime.now(timezone.utc)
messages = [{"role": "user", "content": llm_request.prompt}]
response = await client.chat.completions.create(
model=llm_request.api_provider.model_id,
messages=messages,
temperature=llm_parameters.temperature if llm_parameters else None,
max_tokens=llm_parameters.max_tokens if llm_parameters else None,
top_p=llm_parameters.top_p if llm_parameters else None,
frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
seed=llm_parameters.seed if llm_parameters else None
)
messages = [{"role": "user", "content": prompt}]
create_args = {
"model": llm_request.api_provider.model_id,
"messages": messages,
"temperature": llm_parameters.temperature if llm_parameters else None,
"max_tokens": llm_parameters.max_tokens if llm_parameters else None,
"top_p": llm_parameters.top_p if llm_parameters else None,
"frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
"presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
"seed": llm_parameters.seed if llm_parameters else None
} # 处理format参数
if llm_request.format:
create_args["response_format"] = {"type": "json_object"}
response = await client.chat.completions.create(**create_args)
round_end = datetime.now(timezone.utc)
duration = (round_end - round_start).total_seconds()
@ -84,15 +93,19 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
return llm_request
if __name__ == "__main__":
from json_example import generate_example_json
from sqlmodel import Session, select
from global_var import get_sql_engine, init_global_var
from schema import dataset_item
init_global_var("workdir")
api_state = "1 deepseek-chat"
with Session(get_sql_engine()) as session:
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
llm_request = LLMRequest(
prompt="你好,世界!",
api_provider=api_provider
prompt="测试,随便说点什么",
api_provider=api_provider,
format=generate_example_json(dataset_item)
)
# # 单次调用示例

1
train/__init__.py Normal file
View File

@ -0,0 +1 @@
from .model import *