Compare commits
No commits in common. "202d4c44df4735bf9aa353297996cd398e06f568" and "41447c5ed45f9ab1c9de8d0dc653b4f1db34ead5" have entirely different histories.
202d4c44df
...
41447c5ed4
@@ -1,11 +1,9 @@
|
|||||||
from .init_db import get_sqlite_engine, initialize_sqlite_db
|
from .init_db import get_sqlite_engine, initialize_sqlite_db
|
||||||
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
||||||
from .dataset_store import get_all_dataset
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_sqlite_engine",
|
"get_sqlite_engine",
|
||||||
"initialize_sqlite_db",
|
"initialize_sqlite_db",
|
||||||
"get_prompt_tinydb",
|
"get_prompt_tinydb",
|
||||||
"initialize_prompt_store",
|
"initialize_prompt_store"
|
||||||
"get_all_dataset"
|
|
||||||
]
|
]
|
@@ -1,49 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
|
||||||
from tinydb import TinyDB
|
|
||||||
|
|
||||||
# Add the project root directory to sys.path so that other modules of the project can be imported
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
|
||||||
from schema.dataset import dataset, dataset_item, Q_A
|
|
||||||
|
|
||||||
def get_all_dataset(workdir: str) -> "List[dataset]":
    """Load every JSON file under ``<workdir>/dataset`` as a ``dataset`` object.

    Files are visited in sorted filename order so the result is deterministic.
    A file that cannot be read, parsed, or turned into a ``dataset`` is
    reported on stdout and skipped; it never aborts the scan.

    Args:
        workdir: Path of the working directory that contains the ``dataset``
            sub-directory.

    Returns:
        One ``dataset`` instance per successfully loaded ``*.json`` file, or
        an empty list when ``<workdir>/dataset`` is not a directory.
    """
    dataset_dir = os.path.join(workdir, "dataset")
    # isdir rather than exists: a regular file named "dataset" would pass an
    # existence check and then make os.listdir raise NotADirectoryError.
    if not os.path.isdir(dataset_dir):
        return []

    datasets = []
    # os.listdir returns entries in arbitrary order; sort for a stable result.
    for filename in sorted(os.listdir(dataset_dir)):
        if not filename.endswith(".json"):
            continue
        filepath = os.path.join(dataset_dir, filename)
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            datasets.append(dataset(**data))
        except Exception as e:
            # Deliberately broad (best-effort scan): covers I/O errors,
            # json.JSONDecodeError and failures while constructing dataset.
            print(f"Error loading dataset file {filename}: {str(e)}")
    return datasets
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Manual check: scan the "workdir" folder one level above this file's
    # directory and list every dataset that could be loaded from it.
    here = os.path.dirname(__file__)
    found = get_all_dataset(os.path.join(here, "..", "workdir"))

    # Print a summary line followed by one line per dataset.
    print(f"Found {len(found)} datasets:")
    for ds in found:
        print(f"- {ds.name} (ID: {ds.id})")
|
|
@@ -7,7 +7,6 @@ class doc(BaseModel):
|
|||||||
name: str = Field(default="", description="文档名称")
|
name: str = Field(default="", description="文档名称")
|
||||||
path: str = Field(default="", description="文档路径")
|
path: str = Field(default="", description="文档路径")
|
||||||
markdown_files: list[str] = Field(default_factory=list, description="文档路径列表")
|
markdown_files: list[str] = Field(default_factory=list, description="文档路径列表")
|
||||||
version: Optional[str] = Field(default="", description="文档版本")
|
|
||||||
|
|
||||||
class Q_A(BaseModel):
|
class Q_A(BaseModel):
|
||||||
question: str = Field(default="", min_length=1,description="问题")
|
question: str = Field(default="", min_length=1,description="问题")
|
||||||
@@ -21,7 +20,7 @@ class dataset(BaseModel):
|
|||||||
id: Optional[int] = Field(default=None, description="数据集ID")
|
id: Optional[int] = Field(default=None, description="数据集ID")
|
||||||
name: Optional[str] = Field(default=None, description="数据集名称")
|
name: Optional[str] = Field(default=None, description="数据集名称")
|
||||||
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
|
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
|
||||||
source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
|
source_doc: Optional[list[doc]] = Field(default=None, description="数据集来源文档")
|
||||||
description: Optional[str] = Field(default="", description="数据集描述")
|
description: Optional[str] = Field(default="", description="数据集描述")
|
||||||
created_at: datetime = Field(
|
created_at: datetime = Field(
|
||||||
default_factory=lambda: datetime.now(timezone.utc),
|
default_factory=lambda: datetime.now(timezone.utc),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user