feat(db): 添加数据集存储和读取功能

- 新增 dataset_store.py 文件,实现数据集的存储和读取功能
- 添加 get_all_dataset 函数,用于获取所有数据集
- 使用 tinydb 和 json 进行数据持久化
- 在项目根目录下创建 workdir/dataset 目录用于存储数据集文件
This commit is contained in:
carry 2025-04-09 18:21:27 +08:00
parent 4d77c429bd
commit 202d4c44df
2 changed files with 52 additions and 1 deletions

View File

@ -1,9 +1,11 @@
from .init_db import get_sqlite_engine, initialize_sqlite_db
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
from .dataset_store import get_all_dataset
__all__ = [
"get_sqlite_engine",
"initialize_sqlite_db",
"get_prompt_tinydb",
"initialize_prompt_store"
"initialize_prompt_store",
"get_all_dataset"
]

49
db/dataset_store.py Normal file
View File

@ -0,0 +1,49 @@
import os
import sys
import json
from pathlib import Path
from typing import List
from tinydb import TinyDB
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset import dataset, dataset_item, Q_A
def get_all_dataset(workdir: str) -> List[dataset]:
"""
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
Args:
workdir (str): 工作目录路径
Returns:
List[dataset]: 包含所有数据集对象的列表
"""
dataset_dir = os.path.join(workdir, "dataset")
if not os.path.exists(dataset_dir):
return []
datasets = []
for filename in os.listdir(dataset_dir):
if filename.endswith(".json"):
filepath = os.path.join(dataset_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
datasets.append(dataset(**data))
except (json.JSONDecodeError, Exception) as e:
print(f"Error loading dataset file {filename}: {str(e)}")
continue
return datasets
if __name__ == "__main__":
# 定义工作目录路径
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
# 获取所有数据集
datasets = get_all_dataset(workdir)
# 打印结果
print(f"Found {len(datasets)} datasets:")
for ds in datasets:
print(f"- {ds.name} (ID: {ds.id})")