From 202d4c44df4735bf9aa353297996cd398e06f568 Mon Sep 17 00:00:00 2001 From: carry <2641257231@qq.com> Date: Wed, 9 Apr 2025 18:21:27 +0800 Subject: [PATCH] =?UTF-8?q?feat(db):=20=E6=B7=BB=E5=8A=A0=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E9=9B=86=E5=AD=98=E5=82=A8=E5=92=8C=E8=AF=BB=E5=8F=96?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 dataset_store.py 文件,实现数据集的存储和读取功能 - 添加 get_all_dataset 函数,用于获取所有数据集 - 使用 tinydb 和 json 进行数据持久化 - 在项目根目录下创建 workdir/dataset 目录用于存储数据集文件 --- db/__init__.py | 4 +++- db/dataset_store.py | 49 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 db/dataset_store.py diff --git a/db/__init__.py b/db/__init__.py index 06161f0..93ec022 100644 --- a/db/__init__.py +++ b/db/__init__.py @@ -1,9 +1,11 @@ from .init_db import get_sqlite_engine, initialize_sqlite_db from .prompt_store import get_prompt_tinydb, initialize_prompt_store +from .dataset_store import get_all_dataset __all__ = [ "get_sqlite_engine", "initialize_sqlite_db", "get_prompt_tinydb", - "initialize_prompt_store" + "initialize_prompt_store", + "get_all_dataset" ] \ No newline at end of file diff --git a/db/dataset_store.py b/db/dataset_store.py new file mode 100644 index 0000000..f3c87d3 --- /dev/null +++ b/db/dataset_store.py @@ -0,0 +1,49 @@ +import os +import sys +import json +from pathlib import Path +from typing import List +from tinydb import TinyDB + +# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块 +sys.path.append(str(Path(__file__).resolve().parent.parent)) +from schema.dataset import dataset, dataset_item, Q_A + +def get_all_dataset(workdir: str) -> List[dataset]: + """ + 扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表 + + Args: + workdir (str): 工作目录路径 + + Returns: + List[dataset]: 包含所有数据集对象的列表 + """ + dataset_dir = os.path.join(workdir, "dataset") + if not os.path.exists(dataset_dir): + return [] + + datasets = [] + for filename in os.listdir(dataset_dir): + if filename.endswith(".json"): + filepath = os.path.join(dataset_dir, filename) + try: + with open(filepath, "r", encoding="utf-8") as f: + data = json.load(f) + datasets.append(dataset(**data)) + except (json.JSONDecodeError, Exception) as e: + print(f"Error loading dataset file {filename}: {str(e)}") + continue + + return datasets + + +if __name__ == "__main__": + # 定义工作目录路径 + workdir = os.path.join(os.path.dirname(__file__), "..", "workdir") + # 获取所有数据集 + datasets = get_all_dataset(workdir) + # 打印结果 + print(f"Found {len(datasets)} datasets:") + for ds in datasets: + print(f"- {ds.name} (ID: {ds.id})") \ No newline at end of file