From 2c8e54bb1e1c4de2efc05d98f2fe7df56ee32907 Mon Sep 17 00:00:00 2001
From: carry <2641257231@qq.com>
Date: Wed, 9 Apr 2025 20:49:20 +0800
Subject: [PATCH] feat(dataset): initial implementation of the dataset management page and functionality
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 db/dataset_store.py             | 19 ++++++++++---------
 frontend/dataset_manage_page.py | 19 +++++++++++++++++++
 global_var.py                   |  5 +++--
 3 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/db/dataset_store.py b/db/dataset_store.py
index f3c87d3..79e6d40 100644
--- a/db/dataset_store.py
+++ b/db/dataset_store.py
@@ -3,13 +3,14 @@ import sys
 import json
 from pathlib import Path
 from typing import List
-from tinydb import TinyDB
+from tinydb import TinyDB, Query
+from tinydb.storages import MemoryStorage
 
 # Add the project root directory to sys.path so other project modules can be imported
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from schema.dataset import dataset, dataset_item, Q_A
 
-def get_all_dataset(workdir: str) -> List[dataset]:
+def get_all_dataset(workdir: str) -> TinyDB:
     """
     Scan all json files under the workdir/dataset directory and read them as dataset objects
 
@@ -17,25 +18,25 @@ def get_all_dataset(workdir: str) -> List[dataset]:
         workdir (str): path of the working directory
 
     Returns:
-        List[dataset]: a list containing all dataset objects
+        TinyDB: a TinyDB instance containing all dataset objects
     """
     dataset_dir = os.path.join(workdir, "dataset")
     if not os.path.exists(dataset_dir):
-        return []
+        return TinyDB(storage=MemoryStorage)
 
-    datasets = []
+    db = TinyDB(storage=MemoryStorage)
     for filename in os.listdir(dataset_dir):
         if filename.endswith(".json"):
             filepath = os.path.join(dataset_dir, filename)
             try:
                 with open(filepath, "r", encoding="utf-8") as f:
                     data = json.load(f)
-                    datasets.append(dataset(**data))
+                    db.insert(data)
             except (json.JSONDecodeError, Exception) as e:
                 print(f"Error loading dataset file {filename}: {str(e)}")
                 continue
 
-    return datasets
+    return db
 
 
 if __name__ == "__main__":
@@ -45,5 +46,5 @@ if __name__ == "__main__":
     datasets = get_all_dataset(workdir)
     # Print the results
     print(f"Found {len(datasets)} datasets:")
-    for ds in datasets:
-        print(f"- {ds.name} (ID: {ds.id})")
\ No newline at end of file
+    for ds in datasets.all():
+        print(f"- {ds['name']} (ID: {ds['id']})")
\ No newline at end of file
diff --git a/frontend/dataset_manage_page.py b/frontend/dataset_manage_page.py
index d11f31c..c90c5c0 100644
--- a/frontend/dataset_manage_page.py
+++ b/frontend/dataset_manage_page.py
@@ -1,9 +1,28 @@
 import gradio as gr
+from global_var import datasets
 
 def dataset_manage_page():
     with gr.Blocks() as demo:
         gr.Markdown("## 数据集管理")
         with gr.Row():
+            with gr.Column():
+                # Fetch the dataset list and set the initial value
+                datasets_list = [str(ds["name"]) for ds in datasets.all()]
+                initial_dataset = datasets_list[0] if datasets_list else None
+
+                dataset_dropdown = gr.Dropdown(
+                    choices=datasets_list,
+                    value=initial_dataset,  # set the initially selected item
+                    label="选择数据集",
+                    allow_custom_value=True,
+                    interactive=True
+                )
+                dataset_state = gr.State(value=initial_dataset)  # initialize the state with the initial dataset value
+
             with gr.Column():
                 pass
+
+        # Bind the change event so the state is updated on interaction
+        dataset_dropdown.change(lambda x: x, inputs=dataset_dropdown, outputs=dataset_state)
+
     return demo
\ No newline at end of file
diff --git a/global_var.py b/global_var.py
index 6eca116..b35ab07 100644
--- a/global_var.py
+++ b/global_var.py
@@ -1,6 +1,7 @@
-from db import get_sqlite_engine,get_prompt_tinydb
+from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
 from tools import scan_docs_directory
 
 prompt_store = get_prompt_tinydb("workdir")
 sql_engine = get_sqlite_engine("workdir")
-docs = scan_docs_directory("workdir")
\ No newline at end of file
+docs = scan_docs_directory("workdir")
+datasets = get_all_dataset("workdir")
\ No newline at end of file