feat(db): 添加数据集导出功能
- 新增 save_dataset 函数,用于将 TinyDB 中的数据集保存为单独的 JSON 文件 - 更新 db/__init__.py,添加 save_dataset 函数的引用 - 修改 db/dataset_store.py,实现 save_dataset 函数并添加相关逻辑
This commit is contained in:
parent
87501c9353
commit
d5774eee0c
@ -1,11 +1,12 @@
|
|||||||
from .init_db import load_sqlite_engine, initialize_sqlite_db
|
from .init_db import load_sqlite_engine, initialize_sqlite_db
|
||||||
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
||||||
from .dataset_store import get_all_dataset
|
from .dataset_store import get_all_dataset, save_dataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_sqlite_engine",
|
"load_sqlite_engine",
|
||||||
"initialize_sqlite_db",
|
"initialize_sqlite_db",
|
||||||
"get_prompt_tinydb",
|
"get_prompt_tinydb",
|
||||||
"initialize_prompt_store",
|
"initialize_prompt_store",
|
||||||
"get_all_dataset"
|
"get_all_dataset",
|
||||||
|
"save_dataset"
|
||||||
]
|
]
|
@ -39,6 +39,30 @@ def get_all_dataset(workdir: str) -> TinyDB:
|
|||||||
return db
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def save_dataset(db: TinyDB, workdir: str, name: str = None) -> None:
    """Write datasets held in a TinyDB instance to individual JSON files.

    Each selected dataset document is dumped to
    ``<workdir>/dataset/<dataset name>.json`` (UTF-8, non-ASCII characters
    kept as-is, 2-space indent). The ``dataset`` directory is created if it
    does not exist.

    Saving is best-effort per dataset: a failure on one document is reported
    on stdout and the remaining documents are still processed.

    Args:
        db: TinyDB instance whose documents are the datasets.
        workdir: Working directory; files go into its ``dataset`` subfolder.
        name: Name of the dataset to save. ``None`` saves every dataset.
    """
    dataset_dir = os.path.join(workdir, "dataset")
    os.makedirs(dataset_dir, exist_ok=True)

    datasets = db.all() if name is None else db.search(Query().name == name)

    for dataset in datasets:
        try:
            # The file is named after the dataset's own "name" field. The
            # previous code used that value as a lookup *key*
            # (dataset.get(dataset["name"])), which yields None and made
            # every dataset overwrite the same "None.json".
            filename = f"{dataset['name']}.json"
            filepath = os.path.join(dataset_dir, filename)

            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(dataset, f, ensure_ascii=False, indent=2)
        except Exception as e:
            # Deliberately broad: one bad document (missing "name",
            # unserializable value, unwritable path) must not stop the
            # others from being exported.
            print(f"Error saving dataset {dataset.get('id', 'unknown')}: {e}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 定义工作目录路径
|
# 定义工作目录路径
|
||||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||||
@ -47,4 +71,11 @@ if __name__ == "__main__":
|
|||||||
# 打印结果
|
# 打印结果
|
||||||
print(f"Found {len(datasets)} datasets:")
|
print(f"Found {len(datasets)} datasets:")
|
||||||
for ds in datasets.all():
|
for ds in datasets.all():
|
||||||
print(f"- {ds['name']} (ID: {ds['id']})")
|
print(f"- {ds['name']} (ID: {ds['id']})")
|
||||||
|
|
||||||
|
# 询问要保存的数据集名称
|
||||||
|
name = input("输入要保存的数据集名称(直接回车保存所有): ").strip() or None
|
||||||
|
|
||||||
|
# 保存数据集到文件
|
||||||
|
save_dataset(datasets, workdir, name)
|
||||||
|
print(f"Datasets {'all' if name is None else name} saved to json files")
|
Loading…
x
Reference in New Issue
Block a user