Compare commits


3 Commits

Author SHA1 Message Date
carry
8fb9f785b9 feat(frontend): display Q&A data on the dataset management page
- Add a QA dataset display component
- Dynamically load the corresponding Q&A data when a dataset is selected
- Improve the dataset management page layout
2025-04-09 22:23:55 +08:00
carry
2c8e54bb1e feat(dataset): initial implementation of the dataset management page and its functionality 2025-04-09 20:49:20 +08:00
carry
932d1e2687 refactor(schema): change the default value of the dataset name
- Change the default value of the name field in the dataset class from None to ""
- This ensures the dataset name always has a default empty-string value rather than None, improving data consistency and code robustness
2025-04-09 19:42:00 +08:00
4 changed files with 61 additions and 14 deletions

View File

@@ -3,13 +3,14 @@ import sys
 import json
 from pathlib import Path
 from typing import List
-from tinydb import TinyDB
+from tinydb import TinyDB, Query
+from tinydb.storages import MemoryStorage
 # Add the project root to sys.path so that other modules in the project can be imported
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from schema.dataset import dataset, dataset_item, Q_A
-def get_all_dataset(workdir: str) -> List[dataset]:
+def get_all_dataset(workdir: str) -> TinyDB:
     """
     Scan all JSON files under workdir/dataset and load them as a list of dataset objects
@@ -17,25 +18,25 @@ def get_all_dataset(workdir: str) -> List[dataset]:
         workdir (str): path to the working directory
     Returns:
-        List[dataset]: a list containing all dataset objects
+        TinyDB: a TinyDB instance containing all dataset objects
     """
     dataset_dir = os.path.join(workdir, "dataset")
     if not os.path.exists(dataset_dir):
-        return []
+        return TinyDB(storage=MemoryStorage)
-    datasets = []
+    db = TinyDB(storage=MemoryStorage)
     for filename in os.listdir(dataset_dir):
         if filename.endswith(".json"):
             filepath = os.path.join(dataset_dir, filename)
             try:
                 with open(filepath, "r", encoding="utf-8") as f:
                     data = json.load(f)
-                    datasets.append(dataset(**data))
+                    db.insert(data)
             except (json.JSONDecodeError, Exception) as e:
                 print(f"Error loading dataset file {filename}: {str(e)}")
                 continue
-    return datasets
+    return db
 if __name__ == "__main__":
@@ -45,5 +46,5 @@ if __name__ == "__main__":
     datasets = get_all_dataset(workdir)
     # Print the results
     print(f"Found {len(datasets)} datasets:")
-    for ds in datasets:
-        print(f"- {ds.name} (ID: {ds.id})")
+    for ds in datasets.all():
+        print(f"- {ds['name']} (ID: {ds['id']})")

View File

@@ -1,9 +1,54 @@
 import gradio as gr
+from global_var import datasets
+from tinydb import Query
 def dataset_manage_page():
     with gr.Blocks() as demo:
         gr.Markdown("## 数据集管理")
         with gr.Row():
-            with gr.Column():
-                pass
+            # Fetch the dataset list and set the initial value
+            datasets_list = [str(ds["name"]) for ds in datasets.all()]
+            initial_dataset = datasets_list[0] if datasets_list else None
+            dataset_dropdown = gr.Dropdown(
+                choices=datasets_list,
+                value=initial_dataset,  # set the initially selected item
+                label="选择数据集",
+                allow_custom_value=True,
+                interactive=True
+            )
+            # Add the dataset display component
+            qa_dataset = gr.Dataset(
+                components=["text", "text"],
+                label="问答数据",
+                headers=["问题", "答案"],
+                samples=[["示例问题", "示例答案"]]
+            )
+            def update_qa_display(dataset_name):
+                if not dataset_name:
+                    return {"samples": [], "__type__": "update"}
+                # Fetch the dataset from the database
+                Dataset = Query()
+                ds = datasets.get(Dataset.name == dataset_name)
+                if not ds:
+                    return {"samples": [], "__type__": "update"}
+                # Extract all Q_A entries
+                qa_list = []
+                for item in ds["dataset_items"]:
+                    for qa in item["message"]:
+                        qa_list.append([qa["question"], qa["answer"]])
+                return {"samples": qa_list, "__type__": "update"}
+            # Bind the change event to update the QA data display
+            dataset_dropdown.change(
+                update_qa_display,
+                inputs=dataset_dropdown,
+                outputs=qa_dataset
+            )
     return demo
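
To preview the rewritten page on its own, the returned gr.Blocks object can simply be launched; a hedged sketch, assuming the module is importable as dataset_manage_page and that global_var can build its stores from the local workdir:

from dataset_manage_page import dataset_manage_page  # hypothetical import path

demo = dataset_manage_page()  # builds the Blocks layout shown in the diff above
demo.launch()                 # serve the dataset management UI locally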

View File

@@ -1,6 +1,7 @@
-from db import get_sqlite_engine,get_prompt_tinydb
+from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
 from tools import scan_docs_directory
 prompt_store = get_prompt_tinydb("workdir")
 sql_engine = get_sqlite_engine("workdir")
 docs = scan_docs_directory("workdir")
+datasets = get_all_dataset("workdir")
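
Because these stores are created at import time, the first import of global_var pays the scan cost once and every later import reuses the cached module, so the frontend and other callers share a single TinyDB handle. A small sketch of that behaviour (assuming the project's workdir layout is present so the import succeeds):

import global_var as a
import global_var as b

# Python caches modules, so both names point at the same TinyDB instance
assert a.datasets is b.datasets
print(f"{len(a.datasets)} datasets loaded once at import time")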

View File

@@ -19,7 +19,7 @@ class dataset_item(BaseModel):
 class dataset(BaseModel):
     id: Optional[int] = Field(default=None, description="数据集ID")
-    name: Optional[str] = Field(default=None, description="数据集名称")
+    name: str = Field(default="", description="数据集名称")
     model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
     source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
     description: Optional[str] = Field(default="", description="数据集描述")
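
The effect of the schema change is that a dataset built without an explicit name now carries an empty string instead of None, so string operations on the name never hit an AttributeError. A self-contained illustration with stripped-down stand-in models (not the project's actual classes):

from typing import Optional
from pydantic import BaseModel, Field

class NameBefore(BaseModel):
    name: Optional[str] = Field(default=None, description="dataset name")

class NameAfter(BaseModel):
    name: str = Field(default="", description="dataset name")

print(NameBefore().name)         # None: calling .strip() here would raise AttributeError
print(NameAfter().name.strip())  # '': always a str, safe to manipulate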