feat(dataset): 初步完成数据集管理页面和功能

This commit is contained in:
carry 2025-04-09 20:49:20 +08:00
parent 932d1e2687
commit 2c8e54bb1e
3 changed files with 32 additions and 11 deletions

View File

@ -3,13 +3,14 @@ import sys
import json import json
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from tinydb import TinyDB from tinydb import TinyDB, Query
from tinydb.storages import MemoryStorage
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块 # 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A from schema.dataset import dataset, dataset_item, Q_A
def get_all_dataset(workdir: str) -> TinyDB:
    """
    Load every ``.json`` file under ``<workdir>/dataset`` into an in-memory TinyDB.

    The parsed content of each file is inserted as one document. A file that
    cannot be opened, decoded, parsed, or inserted is reported on stdout and
    skipped, so one bad file does not prevent the others from loading.

    Args:
        workdir (str): Path of the working directory.

    Returns:
        TinyDB: A database backed by ``MemoryStorage`` holding one document
        per successfully loaded file. It is empty when the dataset directory
        does not exist.
    """
    db = TinyDB(storage=MemoryStorage)
    dataset_dir = os.path.join(workdir, "dataset")
    # isdir rather than exists: a regular file named "dataset" would pass an
    # existence check and then make os.listdir raise outside the try block.
    if not os.path.isdir(dataset_dir):
        return db

    # Sort the listing so documents are inserted (and IDs assigned) in a
    # reproducible order; os.listdir returns entries in arbitrary order.
    for filename in sorted(os.listdir(dataset_dir)):
        if not filename.endswith(".json"):
            continue
        filepath = os.path.join(dataset_dir, filename)
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            db.insert(data)
        except Exception as e:
            # Deliberate best-effort: report the failing file and keep going.
            print(f"Error loading dataset file {filename}: {str(e)}")
            continue
    return db
if __name__ == "__main__": if __name__ == "__main__":
@ -45,5 +46,5 @@ if __name__ == "__main__":
datasets = get_all_dataset(workdir) datasets = get_all_dataset(workdir)
# 打印结果 # 打印结果
print(f"Found {len(datasets)} datasets:") print(f"Found {len(datasets)} datasets:")
for ds in datasets: for ds in datasets.all():
print(f"- {ds.name} (ID: {ds.id})") print(f"- {ds['name']} (ID: {ds['id']})")

View File

@ -1,9 +1,28 @@
import gradio as gr import gradio as gr
from global_var import datasets
def dataset_manage_page():
    """
    Build the dataset management page.

    The page contains a dropdown listing the names of the loaded datasets and
    a ``gr.State`` that mirrors the dropdown's current value.

    Returns:
        The ``gr.Blocks`` container holding the page.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 数据集管理")
        with gr.Row():
            with gr.Column():
                # Collect the dataset names. Documents without a "name" key
                # are skipped so a single malformed dataset file cannot raise
                # KeyError while the page is being constructed.
                datasets_list = [
                    str(ds["name"]) for ds in datasets.all() if "name" in ds
                ]
                # Preselect the first dataset, or nothing when none are loaded.
                initial_dataset = datasets_list[0] if datasets_list else None

                dataset_dropdown = gr.Dropdown(
                    choices=datasets_list,
                    value=initial_dataset,  # initially selected item
                    label="选择数据集",
                    allow_custom_value=True,
                    interactive=True
                )
                # State starts out holding the same value as the dropdown.
                dataset_state = gr.State(value=initial_dataset)

            with gr.Column():
                # Empty second column; nothing is rendered here yet.
                pass

        # Copy the dropdown value into the state whenever it changes.
        dataset_dropdown.change(lambda x: x, inputs=dataset_dropdown, outputs=dataset_state)

    return demo

View File

@ -1,6 +1,7 @@
from db import get_sqlite_engine,get_prompt_tinydb from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory from tools import scan_docs_directory
# Working directory handed to every loader below; defined once so the four
# module-level objects cannot drift apart.
WORKDIR = "workdir"

# Shared module-level objects, created once at import time.
# NOTE(review): the exact return types of get_prompt_tinydb, get_sqlite_engine
# and scan_docs_directory are not visible here — confirm in db / tools.
prompt_store = get_prompt_tinydb(WORKDIR)
sql_engine = get_sqlite_engine(WORKDIR)
docs = scan_docs_directory(WORKDIR)
# In-memory TinyDB with one document per JSON file under <WORKDIR>/dataset.
datasets = get_all_dataset(WORKDIR)