Compare commits

..

No commits in common. "8fb9f785b9f663cc5f15b7aad0c7baf3d4bfed8c" and "202d4c44df4735bf9aa353297996cd398e06f568" have entirely different histories.

4 changed files with 14 additions and 61 deletions

View File

@ -3,14 +3,13 @@ import sys
import json
from pathlib import Path
from typing import List
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage
from tinydb import TinyDB
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A
def get_all_dataset(workdir: str) -> TinyDB:
def get_all_dataset(workdir: str) -> List[dataset]:
"""
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
@ -18,25 +17,25 @@ def get_all_dataset(workdir: str) -> TinyDB:
workdir (str): 工作目录路径
Returns:
TinyDB: 包含所有数据集对象的TinyDB对象
List[dataset]: 包含所有数据集对象的列表
"""
dataset_dir = os.path.join(workdir, "dataset")
if not os.path.exists(dataset_dir):
return TinyDB(storage=MemoryStorage)
return []
db = TinyDB(storage=MemoryStorage)
datasets = []
for filename in os.listdir(dataset_dir):
if filename.endswith(".json"):
filepath = os.path.join(dataset_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
db.insert(data)
datasets.append(dataset(**data))
except (json.JSONDecodeError, Exception) as e:
print(f"Error loading dataset file {filename}: {str(e)}")
continue
return db
return datasets
if __name__ == "__main__":
@ -46,5 +45,5 @@ if __name__ == "__main__":
datasets = get_all_dataset(workdir)
# 打印结果
print(f"Found {len(datasets)} datasets:")
for ds in datasets.all():
print(f"- {ds['name']} (ID: {ds['id']})")
for ds in datasets:
print(f"- {ds.name} (ID: {ds.id})")

View File

@ -1,54 +1,9 @@
import gradio as gr
from global_var import datasets
from tinydb import Query
def dataset_manage_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集管理")
with gr.Row():
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in datasets.all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # 设置初始选中项
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# 添加数据集展示组件
qa_dataset = gr.Dataset(
components=["text", "text"],
label="问答数据",
headers=["问题", "答案"],
samples=[["示例问题", "示例答案"]]
)
def update_qa_display(dataset_name):
if not dataset_name:
return {"samples": [], "__type__": "update"}
# 从数据库获取数据集
Dataset = Query()
ds = datasets.get(Dataset.name == dataset_name)
if not ds:
return {"samples": [], "__type__": "update"}
# 提取所有Q_A数据
qa_list = []
for item in ds["dataset_items"]:
for qa in item["message"]:
qa_list.append([qa["question"], qa["answer"]])
return {"samples": qa_list, "__type__": "update"}
# 绑定事件更新QA数据显示
dataset_dropdown.change(
update_qa_display,
inputs=dataset_dropdown,
outputs=qa_dataset
)
with gr.Column():
pass
return demo

View File

@ -1,7 +1,6 @@
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
from db import get_sqlite_engine,get_prompt_tinydb
from tools import scan_docs_directory
prompt_store = get_prompt_tinydb("workdir")
sql_engine = get_sqlite_engine("workdir")
docs = scan_docs_directory("workdir")
datasets = get_all_dataset("workdir")
docs = scan_docs_directory("workdir")

View File

@ -19,7 +19,7 @@ class dataset_item(BaseModel):
class dataset(BaseModel):
id: Optional[int] = Field(default=None, description="数据集ID")
name: str = Field(default="", description="数据集名称")
name: Optional[str] = Field(default=None, description="数据集名称")
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
description: Optional[str] = Field(default="", description="数据集描述")