
- 新增 save_dataset 函数,用于将 TinyDB 中的数据集保存为单独的 JSON 文件 - 更新 db/__init__.py,添加 get_dataset_tinydb 函数的引用 - 修改 db/dataset_store.py,实现 save_dataset 函数并添加相关逻辑
81 lines
2.8 KiB
Python
81 lines
2.8 KiB
Python
import os
|
||
import sys
|
||
import json
|
||
from pathlib import Path
|
||
from typing import List
|
||
from tinydb import TinyDB, Query
|
||
from tinydb.storages import MemoryStorage
|
||
|
||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||
from schema.dataset import Dataset, DatasetItem, Q_A
|
||
|
||
def get_all_dataset(workdir: str) -> TinyDB:
    """
    Load every JSON file under ``<workdir>/dataset`` into an in-memory TinyDB.

    Each ``*.json`` file is parsed and its content inserted as one document.
    Files that cannot be read, parsed, or inserted are reported on stdout and
    skipped, so one bad file does not prevent the others from loading.

    Args:
        workdir (str): Path of the working directory.

    Returns:
        TinyDB: A memory-backed TinyDB holding the loaded datasets. It is
        empty when the dataset directory does not exist.
    """
    db = TinyDB(storage=MemoryStorage)

    dataset_dir = os.path.join(workdir, "dataset")
    # Check for a directory rather than mere existence: a regular file named
    # "dataset" would otherwise make os.listdir() raise.
    if not os.path.isdir(dataset_dir):
        return db

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # insertion order deterministic across runs and platforms.
    for filename in sorted(os.listdir(dataset_dir)):
        if not filename.endswith(".json"):
            continue

        filepath = os.path.join(dataset_dir, filename)
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            db.insert(data)
        except Exception as e:
            # Deliberately broad: loading is best-effort per file.
            print(f"Error loading dataset file {filename}: {str(e)}")

    return db
|
||
|
||
|
||
def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
    """
    Write datasets held in a TinyDB instance to individual JSON files.

    Each selected document is written to ``<workdir>/dataset/<name>.json``,
    where ``<name>`` is the value of the document's ``name`` field. The
    dataset directory is created if it does not exist. A failure while saving
    one dataset is reported on stdout and does not stop the remaining ones.

    Args:
        db (TinyDB): TinyDB instance containing the dataset documents.
        workdir (str): Path of the working directory.
        name (str, optional): Name of the dataset to save; ``None`` saves
            every dataset in ``db``.
    """
    dataset_dir = os.path.join(workdir, "dataset")
    os.makedirs(dataset_dir, exist_ok=True)

    datasets = db.all() if name is None else db.search(Query().name == name)

    for dataset in datasets:
        try:
            # Name the file after the dataset's own name. Looking the name up
            # as a key (dataset.get(dataset['name'])) yields None, which sent
            # every dataset to the same "None.json" file.
            filename = f"{dataset['name']}.json"
            filepath = os.path.join(dataset_dir, filename)

            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(dataset, f, ensure_ascii=False, indent=2)
        except Exception as e:
            # Best-effort: report the failure and continue with the rest.
            print(f"Error saving dataset {dataset.get('id', 'unknown')}: {str(e)}")
|
||
|
||
if __name__ == "__main__":
    # Manual smoke run: load the datasets, list them, then write them back.

    # The working directory sits next to this file's parent package.
    base_dir = os.path.join(os.path.dirname(__file__), "..", "workdir")

    # Load every dataset found under the working directory.
    loaded = get_all_dataset(base_dir)

    # Report what was found.
    print(f"Found {len(loaded)} datasets:")
    for entry in loaded.all():
        print(f"- {entry['name']} (ID: {entry['id']})")

    # Ask which dataset to save; an empty answer means "save all".
    answer = input("输入要保存的数据集名称(直接回车保存所有): ").strip()
    selected = answer if answer else None

    # Write the selected dataset(s) back to JSON files.
    save_dataset(loaded, base_dir, selected)

    label = "all" if selected is None else selected
    print(f"Datasets {label} saved to json files")