Compare commits

9 Commits: 868fcd45ba ... 0a4efa5641

0a4efa5641
994d600221
d5774eee0c
87501c9353
5fc3b4950b
c28e4819d9
e7cf51d662
4c9caff668
9236f49b36
@@ -1,11 +1,12 @@
 from .init_db import load_sqlite_engine, initialize_sqlite_db
 from .prompt_store import get_prompt_tinydb, initialize_prompt_store
-from .dataset_store import get_all_dataset
+from .dataset_store import get_all_dataset, save_dataset
 
 __all__ = [
     "load_sqlite_engine",
     "initialize_sqlite_db",
     "get_prompt_tinydb",
     "initialize_prompt_store",
-    "get_all_dataset"
+    "get_all_dataset",
+    "save_dataset"
 ]
@@ -8,7 +8,7 @@ from tinydb.storages import MemoryStorage
 
 # Add the project root to sys.path so that other modules in the project can be imported
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema.dataset import dataset, dataset_item, Q_A
+from schema.dataset import Dataset, DatasetItem, Q_A
 
 def get_all_dataset(workdir: str) -> TinyDB:
     """
@@ -39,6 +39,30 @@ def get_all_dataset(workdir: str) -> TinyDB:
     return db
 
 
+def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
+    """
+    Save the datasets held in the TinyDB instance as individual json files.
+
+    Args:
+        db (TinyDB): TinyDB instance containing the dataset objects
+        workdir (str): path to the working directory
+        name (str, optional): name of the dataset to save; None saves all
+    """
+    dataset_dir = os.path.join(workdir, "dataset")
+    os.makedirs(dataset_dir, exist_ok=True)
+
+    datasets = db.all() if name is None else db.search(Query().name == name)
+
+    for dataset in datasets:
+        try:
+            filename = f"{dataset.get(dataset['name'])}.json"
+            filepath = os.path.join(dataset_dir, filename)
+
+            with open(filepath, "w", encoding="utf-8") as f:
+                json.dump(dataset, f, ensure_ascii=False, indent=2)
+        except Exception as e:
+            print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
+
 if __name__ == "__main__":
     # Define the working directory path
     workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
@@ -47,4 +71,11 @@ if __name__ == "__main__":
     # Print the results
     print(f"Found {len(datasets)} datasets:")
     for ds in datasets.all():
         print(f"- {ds['name']} (ID: {ds['id']})")
+
+    # Ask which dataset to save
+    name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
+
+    # Save the datasets to files
+    save_dataset(datasets, workdir, name)
+    print(f"Datasets {'all' if name is None else name} saved to json files")
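Review note on the hunk above (an observation, not part of the commit): `filename = f"{dataset.get(dataset['name'])}.json"` uses the dataset's own name as the *key* to look up, so `dataset.get(...)` will normally return `None` and every dataset would be written to `None.json`. A minimal sketch of the presumably intended line:

    filename = f"{dataset['name']}.json"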
@@ -1,13 +1,18 @@
 import gradio as gr
 import sys
+import json
+from tinydb import Query
 from pathlib import Path
 from langchain.prompts import PromptTemplate
 from sqlmodel import Session, select
+from schema import Dataset, DatasetItem, Q_A
+from db.dataset_store import get_all_dataset
 
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema import APIProvider
-from tools import call_openai_api
-from global_var import get_docs, get_prompt_store, get_sql_engine
+from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem
+from db import save_dataset
+from tools import call_openai_api, process_markdown_file, generate_json_example
+from global_var import get_docs, get_prompt_store, get_sql_engine, get_datasets, get_workdir
 
 def dataset_generate_page():
     with gr.Blocks() as demo:
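Review note: after this hunk, `DatasetItem` is imported twice (once via `from schema import Dataset, DatasetItem, Q_A` and again via `from schema import APIProvider, LLMRequest, LLMResponse, DatasetItem`), and `get_all_dataset` does not appear to be used anywhere in the hunks shown; merging the two `schema` imports into one statement would avoid the redundancy.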
@@ -16,13 +21,6 @@ def dataset_generate_page():
             with gr.Column(scale=1):
                 docs_list = [str(doc.name) for doc in get_docs()]
                 initial_doc = docs_list[0] if docs_list else None
-                doc_dropdown = gr.Dropdown(
-                    choices=docs_list,
-                    value=initial_doc,
-                    label="选择文档",
-                    interactive=True
-                )
-                doc_choice = gr.State(value=initial_doc)
                 prompts = get_prompt_store().all()
                 prompt_list = [f"{p['id']} {p['name']}" for p in prompts]
                 initial_prompt = prompt_list[0] if prompt_list else None
@@ -37,14 +35,6 @@ def dataset_generate_page():
                 input_variables = prompt_template.input_variables
                 input_variables.remove("document_slice")
                 initial_dataframe_value = [[var, ""] for var in input_variables]
-
-                prompt_dropdown = gr.Dropdown(
-                    choices=prompt_list,
-                    value=initial_prompt,
-                    label="选择模板",
-                    interactive=True
-                )
-                prompt_choice = gr.State(value=initial_prompt)
                 # Fetch the API Provider list from the database
                 with Session(get_sql_engine()) as session:
                     providers = session.exec(select(APIProvider)).all()
@@ -57,8 +47,18 @@ def dataset_generate_page():
                     label="选择API",
                     interactive=True
                 )
-                api_choice = gr.State(value=initial_api)
-
+                doc_dropdown = gr.Dropdown(
+                    choices=docs_list,
+                    value=initial_doc,
+                    label="选择文档",
+                    interactive=True
+                )
+                prompt_dropdown = gr.Dropdown(
+                    choices=prompt_list,
+                    value=initial_prompt,
+                    label="选择模板",
+                    interactive=True
+                )
                 rounds_input = gr.Number(
                     value=1,
                     label="生成轮次",
@@ -67,10 +67,25 @@ def dataset_generate_page():
                     step=1,
                     interactive=True
                 )
+                concurrency_input = gr.Number(
+                    value=1,
+                    label="并发数",
+                    minimum=1,
+                    maximum=10,
+                    step=1,
+                    interactive=True,
+                    visible=False
+                )
+                dataset_name_input = gr.Textbox(
+                    label="数据集名称",
+                    placeholder="输入数据集保存名称",
+                    interactive=True
+                )
+                prompt_choice = gr.State(value=initial_prompt)
                 generate_button = gr.Button("生成数据集",variant="primary")
-
+                doc_choice = gr.State(value=initial_doc)
                 output_text = gr.Textbox(label="生成结果", interactive=False)
-
+                api_choice = gr.State(value=initial_api)
             with gr.Column(scale=2):
                 variables_dataframe = gr.Dataframe(
                     headers=["变量名", "变量值"],
@@ -79,8 +94,6 @@ def dataset_generate_page():
                     label="变量列表",
                     value=initial_dataframe_value  # Set the initial data
                 )
-
-
             def on_doc_change(selected_doc):
                 return selected_doc
 
@@ -100,8 +113,14 @@ def dataset_generate_page():
                 dataframe_value = [[var, ""] for var in input_variables]
                 return selected_prompt, dataframe_value
 
-            def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, progress=gr.Progress()):
-                doc = [i for i in get_docs() if i.name == doc_state][0].markdown_files
+            def on_generate_click(doc_state, prompt_state, api_state, variables_dataframe, rounds, concurrency, dataset_name, progress=gr.Progress()):
+                dataset_db = get_datasets()
+                if not dataset_db.search(Query().name == dataset_name):
+                    raise gr.Error("数据集名称已存在")
+
+                doc = [i for i in get_docs() if i.name == doc_state][0]
+                doc_files = doc.markdown_files
+                document_slice_list = [process_markdown_file(doc) for doc in doc_files]
                 prompt = [i for i in get_prompt_store().all() if i["id"] == int(prompt_state.split(" ")[0])][0]
                 prompt = PromptTemplate.from_template(prompt["content"])
                 with Session(get_sql_engine()) as session:
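Review note: the duplicate-name guard in this hunk looks inverted. `dataset_db.search(Query().name == dataset_name)` returns an empty list when no dataset has that name, so `if not dataset_db.search(...)` raises "数据集名称已存在" ("dataset name already exists") exactly when the name is still free. A sketch of the presumably intended check:

    # hypothetical fix: reject the name only when it is already taken
    if dataset_db.search(Query().name == dataset_name):
        raise gr.Error("数据集名称已存在")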
@@ -114,23 +133,42 @@ def dataset_generate_page():
                     if var_name:
                         variables_dict[var_name] = var_value
 
                # Inject all parameters except document_slice
                prompt = prompt.partial(**variables_dict)
 
-
-                print(doc)
-                print(prompt.format(document_slice="test"))
-                print(variables_dict)
-
-                import time
-                total_steps = rounds
-                for i in range(total_steps):
-                    # Simulate the workload of each step
-                    time.sleep(0.5)
+                dataset = Dataset(
+                    name=dataset_name,
+                    model_id=[api_provider.model_id],
+                    source_doc=doc,
+                    dataset_items=[]
+                )
 
+                for document_slice in document_slice_list:
+                    request = LLMRequest(api_provider=api_provider,
+                                         prompt=prompt.format(document_slice=document_slice),
+                                         format=generate_json_example(DatasetItem))
+                    call_openai_api(request, rounds)
 
-                    current_progress = (i + 1) / total_steps
-                    progress(current_progress, desc=f"处理步骤 {i + 1}/{total_steps}")
+                    for resp in request.response:
+                        try:
+                            content = json.loads(resp.content)
+                            dataset_item = DatasetItem(
+                                message=[Q_A(
+                                    question=content.get("question", ""),
+                                    answer=content.get("answer", "")
+                                )]
+                            )
+                            dataset.dataset_items.append(dataset_item)
+                        except json.JSONDecodeError as e:
+                            print(f"Failed to parse response: {e}")
 
-                return "all done"
+                # Save the dataset to TinyDB
+                dataset_db.insert(dataset.model_dump())
+
+                save_dataset(dataset_db,get_workdir(),dataset_name)
+
+                return f"数据集 {dataset_name} 生成完成,共 {len(dataset.dataset_items)} 条数据"
 
         doc_dropdown.change(on_doc_change, inputs=doc_dropdown, outputs=doc_choice)
         prompt_dropdown.change(on_prompt_change, inputs=prompt_dropdown, outputs=[prompt_choice, variables_dataframe])
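Review note: the `@@ -68,7 +68,7 @@` hunk further down shows `call_openai_api` is declared `async def`, so calling `call_openai_api(request, rounds)` here without awaiting it only creates a coroutine; `request.response` would still be empty when the inner loop reads it. A minimal sketch of one way to drive it from this synchronous Gradio callback, assuming no event loop is already running in this thread:

    import asyncio
    # hypothetical: run the coroutine to completion so request.response is populated
    asyncio.run(call_openai_api(request, rounds))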
@@ -138,7 +176,7 @@ def dataset_generate_page():
 
         generate_button.click(
             on_generate_click,
-            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input],
+            inputs=[doc_choice, prompt_choice, api_choice, variables_dataframe, rounds_input, concurrency_input, dataset_name_input],
             outputs=output_text
         )
 
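Review note: `concurrency_input` is created with `visible=False` and its value is now wired through to `on_generate_click`, but nothing in the hunks shown actually reads the `concurrency` argument yet; it appears to be plumbing for a future concurrent-generation feature.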
@@ -74,7 +74,7 @@ def train_page():
 
         # Dynamically generate the TensorBoard iframe
         tensorboard_url = f"http://localhost:{tensorboard_port}"
-        tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="500px"></iframe>'
+        tensorboard_iframe_value = f'<iframe src="{tensorboard_url}" width="100%" height="1000px"></iframe>'
         yield "训练开始...", tensorboard_iframe_value  # Yield two values, for the textbox and the html components
 
         try:
@@ -37,10 +37,6 @@ def get_docs():
 def get_datasets():
     return _datasets
 
-def set_datasets(new_datasets):
-    global _datasets
-    _datasets = new_datasets
-
 def get_model():
     return _model
 
@@ -2,7 +2,7 @@ from typing import Optional
 from pydantic import BaseModel, Field
 from datetime import datetime, timezone
 
-class doc(BaseModel):
+class Doc(BaseModel):
     id: Optional[int] = Field(default=None, description="文档ID")
     name: str = Field(default="", description="文档名称")
     path: str = Field(default="", description="文档路径")
@@ -13,18 +13,18 @@ class Q_A(BaseModel):
     question: str = Field(default="", min_length=1,description="问题")
     answer: str = Field(default="", min_length=1, description="答案")
 
-class dataset_item(BaseModel):
+class DatasetItem(BaseModel):
     id: Optional[int] = Field(default=None, description="数据集项ID")
     message: list[Q_A] = Field(description="数据集项内容")
 
-class dataset(BaseModel):
+class Dataset(BaseModel):
     id: Optional[int] = Field(default=None, description="数据集ID")
     name: str = Field(default="", description="数据集名称")
     model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
-    source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
+    source_doc: Optional[Doc] = Field(default=None, description="数据集来源文档")
     description: Optional[str] = Field(default="", description="数据集描述")
     created_at: datetime = Field(
         default_factory=lambda: datetime.now(timezone.utc),
         description="记录创建时间"
     )
-    dataset_items: list[dataset_item] = Field(default_factory=list, description="数据集项列表")
+    dataset_items: list[DatasetItem] = Field(default_factory=list, description="数据集项列表")
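Review note: these renames (`doc` → `Doc`, `dataset_item` → `DatasetItem`, `dataset` → `Dataset`) bring the Pydantic models in line with PEP 8's CapWords convention for class names, and also stop the old class name `dataset` from colliding with same-named local variables, such as the `for dataset in datasets:` loop in `save_dataset` above.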
@@ -33,7 +33,7 @@ class LLMResponse(SQLModel):
     )
     response_id: str = Field(..., description="响应的唯一ID")
     tokens_usage: TokensUsage = Field(default_factory=TokensUsage, description="token使用信息")
-    response_content: dict = Field(default_factory=dict, description="API响应的内容")
+    content: str = Field(default_factory=dict, description="API响应的内容")
     total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
     llm_parameters: Optional[LLMParameters] = Field(default=None, description="LLM参数")
 
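Review note: the new `content` field is annotated `str` but keeps `default_factory=dict`, so a default-constructed `LLMResponse` would hold a `dict` where a string is expected. A sketch of a consistent declaration, assuming plain-text content:

    content: str = Field(default="", description="API响应的内容")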
@@ -1,5 +1,5 @@
 from .parse_markdown import *
 from .document import *
-from .json_example import generate_example_json
+from .json_example import generate_json_example
 from .port import *
 from .reasoning import call_openai_api
@@ -1,19 +1,19 @@
 from typing import List
-from schema.dataset import dataset, dataset_item, Q_A
+from schema.dataset import Dataset, DatasetItem, Q_A
 import json
 
-def convert_json_to_dataset(json_data: List[dict]) -> dataset:
+def convert_json_to_dataset(json_data: List[dict]) -> Dataset:
     # Convert the JSON data into Dataset format
     dataset_items = []
     item_id = 1  # auto-incrementing ID counter
     for item in json_data:
         qa = Q_A(question=item["question"], answer=item["answer"])
-        dataset_item_obj = dataset_item(id=item_id, message=[qa])
+        dataset_item_obj = DatasetItem(id=item_id, message=[qa])
         dataset_items.append(dataset_item_obj)
         item_id += 1  # increment the ID
 
     # Create the Dataset object
-    result_dataset = dataset(
+    result_dataset = Dataset(
         name="Converted Dataset",
         model_id=None,
         description="Dataset converted from JSON",
@@ -4,7 +4,7 @@ from pathlib import Path
 
 # Add the project root to sys.path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
-from schema import doc
+from schema import Doc
 
 def scan_docs_directory(workdir: str):
     docs_dir = os.path.join(workdir, "docs")

@@ -21,7 +21,7 @@ def scan_docs_directory(workdir: str):
         for file in files:
             if file.endswith(".md"):
                 markdown_files.append(os.path.join(root, file))
-        to_return.append(doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
+        to_return.append(Doc(name=doc_name, path=doc_path, markdown_files=markdown_files))
 
     return to_return
 
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union, get_args, get_origin
 import json
 from datetime import datetime, date
 
-def generate_example_json(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
+def generate_json_example(model: type[BaseModel], include_optional: bool = False,list_length = 2) -> str:
     """
     Generate an example JSON data structure from a Pydantic V2 model.
     """

@@ -37,14 +37,14 @@ def generate_example_json(model: type[BaseModel], include_optional: bool = False
         elif field_type is date:
             return date.today().isoformat()
         elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
-            return json.loads(generate_example_json(field_type, include_optional))
+            return json.loads(generate_json_example(field_type, include_optional))
     else:
         # Handle direct (non-generic) type annotations
         if field_type is type(None):
             return None
         try:
             if issubclass(field_type, BaseModel):
-                return json.loads(generate_example_json(field_type, include_optional))
+                return json.loads(generate_json_example(field_type, include_optional))
         except TypeError:
             pass
         return "unknown"

@@ -61,7 +61,7 @@ if __name__ == "__main__":
     from pathlib import Path
     # Add the project root to sys.path
     sys.path.append(str(Path(__file__).resolve().parent.parent))
-    from schema import dataset
+    from schema import Dataset
 
     print("示例 JSON:")
-    print(generate_example_json(dataset))
+    print(generate_json_example(Dataset))
@@ -68,7 +68,7 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
             llm_request.response.append(LLMResponse(
                 response_id=response.id,
                 tokens_usage=tokens_usage,
-                response_content={"content": response.choices[0].message.content},
+                content = response.choices[0].message.content,
                 total_duration=duration,
                 llm_parameters=llm_parameters
             ))
@@ -79,7 +79,7 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
 
             llm_request.response.append(LLMResponse(
                 response_id=f"error-round-{i+1}",
-                response_content={"error": str(e)},
+                content={"error": str(e)},
                 total_duration=duration
             ))
             if llm_request.error is None:
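Review note: with `content` now typed `str` (see the schema hunk above), `content={"error": str(e)}` assigns a `dict` to a string field. Serializing the error would keep the type consistent, e.g. (hypothetical, assuming `json` is imported in this module):

    content=json.dumps({"error": str(e)}),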
@@ -93,10 +93,10 @@ async def call_openai_api(llm_request: LLMRequest, rounds: int = 1, llm_paramete
     return llm_request
 
 if __name__ == "__main__":
-    from json_example import generate_example_json
+    from json_example import generate_json_example
     from sqlmodel import Session, select
     from global_var import get_sql_engine, init_global_var
-    from schema import dataset_item
+    from schema import DatasetItem
 
     init_global_var("workdir")
     api_state = "1 deepseek-chat"
@@ -105,7 +105,7 @@ if __name__ == "__main__":
     llm_request = LLMRequest(
         prompt="测试,随便说点什么",
         api_provider=api_provider,
-        format=generate_example_json(dataset_item)
+        format=generate_json_example(DatasetItem)
     )
 
     # # Single-call example
@@ -120,4 +120,4 @@ if __name__ == "__main__":
     print(f"\n3次调用结果 - 总耗时: {result.total_duration:.2f}s")
     print(f"总token使用: prompt={result.total_tokens_usage.prompt_tokens}, completion={result.total_tokens_usage.completion_tokens}")
     for i, resp in enumerate(result.response, 1):
-        print(f"响应{i}: {resp.response_content}")
+        print(f"响应{i}: {resp.content}")