feat(db): 添加数据集存储和读取功能
- 新增 dataset_store.py 文件,实现数据集的存储和读取功能
- 添加 get_all_dataset 函数,用于获取所有数据集
- 使用 tinydb 和 json 进行数据持久化
- 在项目根目录下创建 workdir/dataset 目录用于存储数据集文件
This commit is contained in:
parent
4d77c429bd
commit
202d4c44df
@ -1,9 +1,11 @@
|
||||
from .init_db import get_sqlite_engine, initialize_sqlite_db
|
||||
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
||||
from .dataset_store import get_all_dataset
|
||||
|
||||
__all__ = [
|
||||
"get_sqlite_engine",
|
||||
"initialize_sqlite_db",
|
||||
"get_prompt_tinydb",
|
||||
"initialize_prompt_store"
|
||||
"initialize_prompt_store",
|
||||
"get_all_dataset"
|
||||
]
|
49
db/dataset_store.py
Normal file
49
db/dataset_store.py
Normal file
@ -0,0 +1,49 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tinydb import TinyDB
|
||||
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.dataset import dataset, dataset_item, Q_A
|
||||
|
||||
def get_all_dataset(workdir: str) -> List["dataset"]:
    """Load every dataset JSON file found under ``<workdir>/dataset``.

    Only directory entries whose name ends with ``.json`` are considered.
    Each file is read as UTF-8 JSON and the decoded value is unpacked as
    keyword arguments into ``dataset``. Loading is best-effort: a file that
    cannot be opened, parsed, or converted is reported on stdout and skipped,
    so one bad file never hides the others.

    Args:
        workdir: Working directory that contains the ``dataset`` subdirectory.

    Returns:
        The successfully loaded ``dataset`` objects, ordered by file name.
        An empty list when ``<workdir>/dataset`` does not exist or is not a
        directory.
    """
    dataset_dir = os.path.join(workdir, "dataset")
    # isdir rather than exists: a regular file sitting at this path would
    # otherwise make os.listdir raise NotADirectoryError.
    if not os.path.isdir(dataset_dir):
        return []

    datasets = []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for filename in sorted(os.listdir(dataset_dir)):
        if not filename.endswith(".json"):
            continue

        filepath = os.path.join(dataset_dir, filename)
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            datasets.append(dataset(**data))
        except Exception as e:
            # Deliberately broad: I/O errors, malformed JSON and rejected
            # payloads are all treated the same way (report and move on).
            print(f"Error loading dataset file {filename}: {str(e)}")

    return datasets
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc manual check: list the datasets stored in the "workdir"
    # directory located one level above this module's directory.
    module_dir = os.path.dirname(__file__)
    workdir = os.path.join(module_dir, "..", "workdir")

    datasets = get_all_dataset(workdir)

    # Summary line first, then one line per loaded dataset.
    print(f"Found {len(datasets)} datasets:")
    for ds in datasets:
        print(f"- {ds.name} (ID: {ds.id})")
|
Loading…
x
Reference in New Issue
Block a user