Compare commits

...

3 Commits

Author SHA1 Message Date
carry
8fb9f785b9 feat(frontend): 展示数据集管理页面的问答数据
- 添加 QA 数据集展示组件
- 实现数据集选择时动态加载对应的问答数据
- 优化数据集管理页面布局
2025-04-09 22:23:55 +08:00
carry
2c8e54bb1e feat(dataset): 初步完成数据集管理页面和功能 2025-04-09 20:49:20 +08:00
carry
932d1e2687 refactor(schema): 修改数据集名称默认值
- 将 dataset 类中的 name 字段默认值从 None 改为 ""
- 这个改动确保了数据集名称始终有一个默认的空字符串值,而不是 None,提高了数据一致性和代码健壮性
2025-04-09 19:42:00 +08:00
4 changed files with 61 additions and 14 deletions

View File

@ -3,13 +3,14 @@ import sys
import json
from pathlib import Path
from typing import List
from tinydb import TinyDB
from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A
def get_all_dataset(workdir: str) -> List[dataset]:
def get_all_dataset(workdir: str) -> TinyDB:
"""
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
@ -17,25 +18,25 @@ def get_all_dataset(workdir: str) -> List[dataset]:
workdir (str): 工作目录路径
Returns:
List[dataset]: 包含所有数据集对象的列表
TinyDB: 包含所有数据集对象的TinyDB对象
"""
dataset_dir = os.path.join(workdir, "dataset")
if not os.path.exists(dataset_dir):
return []
return TinyDB(storage=MemoryStorage)
datasets = []
db = TinyDB(storage=MemoryStorage)
for filename in os.listdir(dataset_dir):
if filename.endswith(".json"):
filepath = os.path.join(dataset_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
datasets.append(dataset(**data))
db.insert(data)
except (json.JSONDecodeError, Exception) as e:
print(f"Error loading dataset file {filename}: {str(e)}")
continue
return datasets
return db
if __name__ == "__main__":
@ -45,5 +46,5 @@ if __name__ == "__main__":
datasets = get_all_dataset(workdir)
# 打印结果
print(f"Found {len(datasets)} datasets:")
for ds in datasets:
print(f"- {ds.name} (ID: {ds.id})")
for ds in datasets.all():
print(f"- {ds['name']} (ID: {ds['id']})")

View File

@ -1,9 +1,54 @@
import gradio as gr
from global_var import datasets
from tinydb import Query
def dataset_manage_page():
with gr.Blocks() as demo:
gr.Markdown("## 数据集管理")
with gr.Row():
with gr.Column():
pass
# 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in datasets.all()]
initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown(
choices=datasets_list,
value=initial_dataset, # 设置初始选中项
label="选择数据集",
allow_custom_value=True,
interactive=True
)
# 添加数据集展示组件
qa_dataset = gr.Dataset(
components=["text", "text"],
label="问答数据",
headers=["问题", "答案"],
samples=[["示例问题", "示例答案"]]
)
def update_qa_display(dataset_name):
    """Build the update payload for the QA sample table of one dataset.

    Args:
        dataset_name: Name selected in the dropdown; may be empty or None.

    Returns:
        dict: ``{"samples": [[question, answer], ...], "__type__": "update"}``.
        ``samples`` is empty when no name is given or no dataset matches.
    """
    if not dataset_name:
        return {"samples": [], "__type__": "update"}

    # Look the dataset up by name in the TinyDB store.
    ds = datasets.get(Query().name == dataset_name)
    if not ds:
        return {"samples": [], "__type__": "update"}

    # Records are stored as raw JSON, so a missing or null "dataset_items" /
    # "message" is treated as empty (and a missing question/answer as "")
    # instead of raising inside the UI callback.
    qa_list = [
        [qa.get("question", ""), qa.get("answer", "")]
        for item in ds.get("dataset_items") or []
        for qa in item.get("message") or []
    ]
    return {"samples": qa_list, "__type__": "update"}
# 绑定事件更新QA数据显示
dataset_dropdown.change(
update_qa_display,
inputs=dataset_dropdown,
outputs=qa_dataset
)
return demo

View File

@ -1,6 +1,7 @@
from db import get_sqlite_engine,get_prompt_tinydb
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory
prompt_store = get_prompt_tinydb("workdir")
sql_engine = get_sqlite_engine("workdir")
docs = scan_docs_directory("workdir")
docs = scan_docs_directory("workdir")
datasets = get_all_dataset("workdir")

View File

@ -19,7 +19,7 @@ class dataset_item(BaseModel):
class dataset(BaseModel):
id: Optional[int] = Field(default=None, description="数据集ID")
name: Optional[str] = Field(default=None, description="数据集名称")
name: str = Field(default="", description="数据集名称")
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
description: Optional[str] = Field(default="", description="数据集描述")