Compare commits
6 Commits
90fde639ff
...
868fcd45ba
Author | SHA1 | Date | |
---|---|---|---|
![]() |
868fcd45ba | ||
![]() |
5a21c8598a | ||
![]() |
1e829c9268 | ||
![]() |
9fc3ab904b | ||
![]() |
d827f9758f | ||
![]() |
ff1e9731bc |
@ -6,6 +6,7 @@ from sqlmodel import Session, select
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import APIProvider
|
||||
from tools import call_openai_api
|
||||
from global_var import get_docs, get_prompt_store, get_sql_engine
|
||||
|
||||
def dataset_generate_page():
|
||||
@ -34,6 +35,7 @@ def dataset_generate_page():
|
||||
prompt_content = prompt_data["content"]
|
||||
prompt_template = PromptTemplate.from_template(prompt_content)
|
||||
input_variables = prompt_template.input_variables
|
||||
input_variables.remove("document_slice")
|
||||
initial_dataframe_value = [[var, ""] for var in input_variables]
|
||||
|
||||
prompt_dropdown = gr.Dropdown(
|
||||
@ -95,22 +97,30 @@ def dataset_generate_page():
|
||||
input_variables = prompt_template.input_variables
|
||||
input_variables.remove("document_slice")
|
||||
dataframe_value = [] if input_variables is None else input_variables
|
||||
dataframe_value = [[var, ""] for var in input_variables]
|
||||
return selected_prompt, dataframe_value
|
||||
|
||||
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
|
||||
doc = [i for i in get_docs() if i.name == doc_state][0]
|
||||
doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files
|
||||
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
|
||||
prompt = PromptTemplate.from_template(prompt["content"])
|
||||
with Session(get_sql_engine()) as session:
|
||||
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
||||
|
||||
variables_dict = {}
|
||||
# 正确遍历DataFrame的行数据
|
||||
for _, row in variables_dataframe.iterrows():
|
||||
var_name = row['变量名'].strip()
|
||||
var_value = row['变量值'].strip()
|
||||
if var_name:
|
||||
variables_dict[var_name] = var_value
|
||||
|
||||
# 注入除document_slice以外的所有参数
|
||||
prompt = prompt.partial(**variables_dict)
|
||||
|
||||
print(doc)
|
||||
print(prompt.format(document_slice="test"))
|
||||
print(variables_dict)
|
||||
|
||||
import time
|
||||
total_steps = rounds
|
||||
for i in range(total_steps):
|
||||
|
@ -7,7 +7,7 @@ import torch
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
|
||||
from tools.model import get_model_name
|
||||
from train import get_model_name
|
||||
|
||||
def model_manage_page():
|
||||
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
|
||||
|
@ -8,7 +8,8 @@ from transformers import TrainerCallback
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
|
||||
from tools import train_model, find_available_port
|
||||
from tools import find_available_port
|
||||
from train import train_model
|
||||
|
||||
def train_page():
|
||||
with gr.Blocks() as demo:
|
||||
|
2
main.py
2
main.py
@ -1,5 +1,5 @@
|
||||
import gradio as gr
|
||||
import unsloth
|
||||
import train
|
||||
from frontend import *
|
||||
from db import initialize_sqlite_db, initialize_prompt_store
|
||||
from global_var import init_global_var, get_sql_engine, get_prompt_store
|
||||
|
@ -1,5 +1,5 @@
|
||||
from .parse_markdown import parse_markdown
|
||||
from .parse_markdown import *
|
||||
from .document import *
|
||||
from .json_example import generate_example_json
|
||||
from .model import *
|
||||
from .port import *
|
||||
from .reasoning import call_openai_api
|
@ -3,17 +3,19 @@ from typing import Any, Dict, List, Optional, Union, get_args, get_origin
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
|
||||
def generate_example_json(model: type[BaseModel]) -> str:
|
||||
def generate_example_json(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
|
||||
"""
|
||||
根据 Pydantic V2 模型生成示例 JSON 数据结构。
|
||||
"""
|
||||
|
||||
def _generate_example(field_type: Any) -> Any:
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
if origin is list or origin is List:
|
||||
return [_generate_example(args[0])] if args else []
|
||||
# 生成多个元素,这里生成 3 个
|
||||
result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
|
||||
result.append("......")
|
||||
return result
|
||||
elif origin is dict or origin is Dict:
|
||||
if len(args) == 2:
|
||||
return {"key": _generate_example(args[1])}
|
||||
@ -35,21 +37,22 @@ def generate_example_json(model: type[BaseModel]) -> str:
|
||||
elif field_type is date:
|
||||
return date.today().isoformat()
|
||||
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return json.loads(generate_example_json(field_type))
|
||||
return json.loads(generate_example_json(field_type, include_optional))
|
||||
else:
|
||||
# 处理直接类型注解(非泛型)
|
||||
if field_type is type(None):
|
||||
return None
|
||||
try:
|
||||
if issubclass(field_type, BaseModel):
|
||||
return json.loads(generate_example_json(field_type))
|
||||
return json.loads(generate_example_json(field_type, include_optional))
|
||||
except TypeError:
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
example_data = {}
|
||||
for field_name, field in model.model_fields.items():
|
||||
example_data[field_name] = _generate_example(field.annotation)
|
||||
if include_optional or not isinstance(field.default, type(None)):
|
||||
example_data[field_name] = _generate_example(field.annotation)
|
||||
|
||||
return json.dumps(example_data, indent=2, default=str)
|
||||
|
||||
|
@ -16,21 +16,30 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
|
||||
|
||||
total_duration = 0.0
|
||||
total_tokens = TokensUsage()
|
||||
prompt = llm_request.prompt
|
||||
round_start = datetime.now(timezone.utc)
|
||||
if llm_request.format:
|
||||
prompt += "\n请以JSON格式返回结果" + llm_request.format
|
||||
|
||||
for i in range(rounds):
|
||||
round_start = datetime.now(timezone.utc)
|
||||
try:
|
||||
round_start = datetime.now(timezone.utc)
|
||||
messages = [{"role": "user", "content": llm_request.prompt}]
|
||||
response = await client.chat.completions.create(
|
||||
model=llm_request.api_provider.model_id,
|
||||
messages=messages,
|
||||
temperature=llm_parameters.temperature if llm_parameters else None,
|
||||
max_tokens=llm_parameters.max_tokens if llm_parameters else None,
|
||||
top_p=llm_parameters.top_p if llm_parameters else None,
|
||||
frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
|
||||
presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
|
||||
seed=llm_parameters.seed if llm_parameters else None
|
||||
)
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
create_args = {
|
||||
"model": llm_request.api_provider.model_id,
|
||||
"messages": messages,
|
||||
"temperature": llm_parameters.temperature if llm_parameters else None,
|
||||
"max_tokens": llm_parameters.max_tokens if llm_parameters else None,
|
||||
"top_p": llm_parameters.top_p if llm_parameters else None,
|
||||
"frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
|
||||
"presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
|
||||
"seed": llm_parameters.seed if llm_parameters else None
|
||||
} # 处理format参数
|
||||
|
||||
if llm_request.format:
|
||||
create_args["response_format"] = {"type": "json_object"}
|
||||
|
||||
response = await client.chat.completions.create(**create_args)
|
||||
|
||||
round_end = datetime.now(timezone.utc)
|
||||
duration = (round_end - round_start).total_seconds()
|
||||
@ -84,15 +93,19 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
|
||||
return llm_request
|
||||
|
||||
if __name__ == "__main__":
|
||||
from json_example import generate_example_json
|
||||
from sqlmodel import Session, select
|
||||
from global_var import get_sql_engine, init_global_var
|
||||
from schema import dataset_item
|
||||
|
||||
init_global_var("workdir")
|
||||
api_state = "1 deepseek-chat"
|
||||
with Session(get_sql_engine()) as session:
|
||||
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
||||
llm_request = LLMRequest(
|
||||
prompt="你好,世界!",
|
||||
api_provider=api_provider
|
||||
prompt="测试,随便说点什么",
|
||||
api_provider=api_provider,
|
||||
format=generate_example_json(dataset_item)
|
||||
)
|
||||
|
||||
# # 单次调用示例
|
||||
|
1
train/__init__.py
Normal file
1
train/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .model import *
|
Loading…
x
Reference in New Issue
Block a user