Compare commits
6 Commits
90fde639ff
...
868fcd45ba
Author | SHA1 | Date | |
---|---|---|---|
![]() |
868fcd45ba | ||
![]() |
5a21c8598a | ||
![]() |
1e829c9268 | ||
![]() |
9fc3ab904b | ||
![]() |
d827f9758f | ||
![]() |
ff1e9731bc |
@ -6,6 +6,7 @@ from sqlmodel import Session, select
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from schema import APIProvider
|
from schema import APIProvider
|
||||||
|
from tools import call_openai_api
|
||||||
from global_var import get_docs, get_prompt_store, get_sql_engine
|
from global_var import get_docs, get_prompt_store, get_sql_engine
|
||||||
|
|
||||||
def dataset_generate_page():
|
def dataset_generate_page():
|
||||||
@ -34,6 +35,7 @@ def dataset_generate_page():
|
|||||||
prompt_content = prompt_data["content"]
|
prompt_content = prompt_data["content"]
|
||||||
prompt_template = PromptTemplate.from_template(prompt_content)
|
prompt_template = PromptTemplate.from_template(prompt_content)
|
||||||
input_variables = prompt_template.input_variables
|
input_variables = prompt_template.input_variables
|
||||||
|
input_variables.remove("document_slice")
|
||||||
initial_dataframe_value = [[var, ""] for var in input_variables]
|
initial_dataframe_value = [[var, ""] for var in input_variables]
|
||||||
|
|
||||||
prompt_dropdown = gr.Dropdown(
|
prompt_dropdown = gr.Dropdown(
|
||||||
@ -95,22 +97,30 @@ def dataset_generate_page():
|
|||||||
input_variables = prompt_template.input_variables
|
input_variables = prompt_template.input_variables
|
||||||
input_variables.remove("document_slice")
|
input_variables.remove("document_slice")
|
||||||
dataframe_value = [] if input_variables is None else input_variables
|
dataframe_value = [] if input_variables is None else input_variables
|
||||||
|
dataframe_value = [[var, ""] for var in input_variables]
|
||||||
return selected_prompt, dataframe_value
|
return selected_prompt, dataframe_value
|
||||||
|
|
||||||
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
|
def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
|
||||||
doc = [i for i in get_docs() if i.name == doc_state][0]
|
doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files
|
||||||
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
|
prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
|
||||||
|
prompt = PromptTemplate.from_template(prompt["content"])
|
||||||
with Session(get_sql_engine()) as session:
|
with Session(get_sql_engine()) as session:
|
||||||
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
||||||
|
|
||||||
variables_dict = {}
|
variables_dict = {}
|
||||||
# 正确遍历DataFrame的行数据
|
|
||||||
for _, row in variables_dataframe.iterrows():
|
for _, row in variables_dataframe.iterrows():
|
||||||
var_name = row['变量名'].strip()
|
var_name = row['变量名'].strip()
|
||||||
var_value = row['变量值'].strip()
|
var_value = row['变量值'].strip()
|
||||||
if var_name:
|
if var_name:
|
||||||
variables_dict[var_name] = var_value
|
variables_dict[var_name] = var_value
|
||||||
|
|
||||||
|
# 注入除document_slice以外的所有参数
|
||||||
|
prompt = prompt.partial(**variables_dict)
|
||||||
|
|
||||||
|
print(doc)
|
||||||
|
print(prompt.format(document_slice="test"))
|
||||||
|
print(variables_dict)
|
||||||
|
|
||||||
import time
|
import time
|
||||||
total_steps = rounds
|
total_steps = rounds
|
||||||
for i in range(total_steps):
|
for i in range(total_steps):
|
||||||
|
@ -7,7 +7,7 @@ import torch
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
|
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
|
||||||
from tools.model import get_model_name
|
from train import get_model_name
|
||||||
|
|
||||||
def model_manage_page():
|
def model_manage_page():
|
||||||
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
|
workdir = "workdir" # 假设workdir是当前工作目录下的一个文件夹
|
||||||
|
@ -8,7 +8,8 @@ from transformers import TrainerCallback
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
|
from global_var import get_model, get_tokenizer, get_datasets, get_workdir
|
||||||
from tools import train_model, find_available_port
|
from tools import find_available_port
|
||||||
|
from train import train_model
|
||||||
|
|
||||||
def train_page():
|
def train_page():
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
|
2
main.py
2
main.py
@ -1,5 +1,5 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import unsloth
|
import train
|
||||||
from frontend import *
|
from frontend import *
|
||||||
from db import initialize_sqlite_db, initialize_prompt_store
|
from db import initialize_sqlite_db, initialize_prompt_store
|
||||||
from global_var import init_global_var, get_sql_engine, get_prompt_store
|
from global_var import init_global_var, get_sql_engine, get_prompt_store
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from .parse_markdown import parse_markdown
|
from .parse_markdown import *
|
||||||
from .document import *
|
from .document import *
|
||||||
from .json_example import generate_example_json
|
from .json_example import generate_example_json
|
||||||
from .model import *
|
from .port import *
|
||||||
from .port import *
|
from .reasoning import call_openai_api
|
@ -3,17 +3,19 @@ from typing import Any, Dict, List, Optional, Union, get_args, get_origin
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime, date
|
from datetime import datetime, date
|
||||||
|
|
||||||
def generate_example_json(model: type[BaseModel]) -> str:
|
def generate_example_json(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
|
||||||
"""
|
"""
|
||||||
根据 Pydantic V2 模型生成示例 JSON 数据结构。
|
根据 Pydantic V2 模型生成示例 JSON 数据结构。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _generate_example(field_type: Any) -> Any:
|
def _generate_example(field_type: Any) -> Any:
|
||||||
origin = get_origin(field_type)
|
origin = get_origin(field_type)
|
||||||
args = get_args(field_type)
|
args = get_args(field_type)
|
||||||
|
|
||||||
if origin is list or origin is List:
|
if origin is list or origin is List:
|
||||||
return [_generate_example(args[0])] if args else []
|
# Generate `list_length` example elements (default 2, not a fixed 3)
|
||||||
|
result = [_generate_example(args[0]) for _ in range(list_length)] if args else []
|
||||||
|
result.append("......")
|
||||||
|
return result
|
||||||
elif origin is dict or origin is Dict:
|
elif origin is dict or origin is Dict:
|
||||||
if len(args) == 2:
|
if len(args) == 2:
|
||||||
return {"key": _generate_example(args[1])}
|
return {"key": _generate_example(args[1])}
|
||||||
@ -35,21 +37,22 @@ def generate_example_json(model: type[BaseModel]) -> str:
|
|||||||
elif field_type is date:
|
elif field_type is date:
|
||||||
return date.today().isoformat()
|
return date.today().isoformat()
|
||||||
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||||
return json.loads(generate_example_json(field_type))
|
return json.loads(generate_example_json(field_type, include_optional))
|
||||||
else:
|
else:
|
||||||
# 处理直接类型注解(非泛型)
|
# 处理直接类型注解(非泛型)
|
||||||
if field_type is type(None):
|
if field_type is type(None):
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
if issubclass(field_type, BaseModel):
|
if issubclass(field_type, BaseModel):
|
||||||
return json.loads(generate_example_json(field_type))
|
return json.loads(generate_example_json(field_type, include_optional))
|
||||||
except TypeError:
|
except TypeError:
|
||||||
pass
|
pass
|
||||||
return "unknown"
|
return "unknown"
|
||||||
|
|
||||||
example_data = {}
|
example_data = {}
|
||||||
for field_name, field in model.model_fields.items():
|
for field_name, field in model.model_fields.items():
|
||||||
example_data[field_name] = _generate_example(field.annotation)
|
if include_optional or not isinstance(field.default, type(None)):
|
||||||
|
example_data[field_name] = _generate_example(field.annotation)
|
||||||
|
|
||||||
return json.dumps(example_data, indent=2, default=str)
|
return json.dumps(example_data, indent=2, default=str)
|
||||||
|
|
||||||
|
@ -16,21 +16,30 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
|
|||||||
|
|
||||||
total_duration = 0.0
|
total_duration = 0.0
|
||||||
total_tokens = TokensUsage()
|
total_tokens = TokensUsage()
|
||||||
|
prompt = llm_request.prompt
|
||||||
|
round_start = datetime.now(timezone.utc)
|
||||||
|
if llm_request.format:
|
||||||
|
prompt += "\n请以JSON格式返回结果" + llm_request.format
|
||||||
|
|
||||||
for i in range(rounds):
|
for i in range(rounds):
|
||||||
|
round_start = datetime.now(timezone.utc)
|
||||||
try:
|
try:
|
||||||
round_start = datetime.now(timezone.utc)
|
messages = [{"role": "user", "content": prompt}]
|
||||||
messages = [{"role": "user", "content": llm_request.prompt}]
|
create_args = {
|
||||||
response = await client.chat.completions.create(
|
"model": llm_request.api_provider.model_id,
|
||||||
model=llm_request.api_provider.model_id,
|
"messages": messages,
|
||||||
messages=messages,
|
"temperature": llm_parameters.temperature if llm_parameters else None,
|
||||||
temperature=llm_parameters.temperature if llm_parameters else None,
|
"max_tokens": llm_parameters.max_tokens if llm_parameters else None,
|
||||||
max_tokens=llm_parameters.max_tokens if llm_parameters else None,
|
"top_p": llm_parameters.top_p if llm_parameters else None,
|
||||||
top_p=llm_parameters.top_p if llm_parameters else None,
|
"frequency_penalty": llm_parameters.frequency_penalty if llm_parameters else None,
|
||||||
frequency_penalty=llm_parameters.frequency_penalty if llm_parameters else None,
|
"presence_penalty": llm_parameters.presence_penalty if llm_parameters else None,
|
||||||
presence_penalty=llm_parameters.presence_penalty if llm_parameters else None,
|
"seed": llm_parameters.seed if llm_parameters else None
|
||||||
seed=llm_parameters.seed if llm_parameters else None
|
} # 处理format参数
|
||||||
)
|
|
||||||
|
if llm_request.format:
|
||||||
|
create_args["response_format"] = {"type": "json_object"}
|
||||||
|
|
||||||
|
response = await client.chat.completions.create(**create_args)
|
||||||
|
|
||||||
round_end = datetime.now(timezone.utc)
|
round_end = datetime.now(timezone.utc)
|
||||||
duration = (round_end - round_start).total_seconds()
|
duration = (round_end - round_start).total_seconds()
|
||||||
@ -84,15 +93,19 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
|
|||||||
return llm_request
|
return llm_request
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
from json_example import generate_example_json
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from global_var import get_sql_engine, init_global_var
|
from global_var import get_sql_engine, init_global_var
|
||||||
|
from schema import dataset_item
|
||||||
|
|
||||||
init_global_var("workdir")
|
init_global_var("workdir")
|
||||||
api_state = "1 deepseek-chat"
|
api_state = "1 deepseek-chat"
|
||||||
with Session(get_sql_engine()) as session:
|
with Session(get_sql_engine()) as session:
|
||||||
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
api_provider = session.exec(select(APIProvider).where(APIProvider.id == int(api_state.split(" ")[0]))).first()
|
||||||
llm_request = LLMRequest(
|
llm_request = LLMRequest(
|
||||||
prompt="你好,世界!",
|
prompt="测试,随便说点什么",
|
||||||
api_provider=api_provider
|
api_provider=api_provider,
|
||||||
|
format=generate_example_json(dataset_item)
|
||||||
)
|
)
|
||||||
|
|
||||||
# # 单次调用示例
|
# # 单次调用示例
|
||||||
|
1
train/__init__.py
Normal file
1
train/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .model import *
|
Loading…
x
Reference in New Issue
Block a user