gzhu-biyesheji/db/dataset_store.py
carry 202d4c44df feat(db): 添加数据集存储和读取功能
- 新增 dataset_store.py 文件,实现数据集的存储和读取功能
- 添加 get_all_dataset 函数,用于获取所有数据集
- 使用 tinydb 和 json 进行数据持久化
- 在项目根目录下创建 workdir/dataset 目录用于存储数据集文件
2025-04-09 18:21:27 +08:00

49 lines
1.5 KiB
Python

import os
import sys
import json
from pathlib import Path
from typing import List
from tinydb import TinyDB
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A
def get_all_dataset(workdir: str) -> List[dataset]:
"""
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
Args:
workdir (str): 工作目录路径
Returns:
List[dataset]: 包含所有数据集对象的列表
"""
dataset_dir = os.path.join(workdir, "dataset")
if not os.path.exists(dataset_dir):
return []
datasets = []
for filename in os.listdir(dataset_dir):
if filename.endswith(".json"):
filepath = os.path.join(dataset_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
datasets.append(dataset(**data))
except (json.JSONDecodeError, Exception) as e:
print(f"Error loading dataset file {filename}: {str(e)}")
continue
return datasets
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取所有数据集
datasets = get_all_dataset(workdir)
# 打印结果
print(f"Found {len(datasets)} datasets:")
for ds in datasets:
print(f"- {ds.name} (ID: {ds.id})")