Compare commits
3 Commits
202d4c44df
...
8fb9f785b9
Author | SHA1 | Date | |
---|---|---|---|
![]() |
8fb9f785b9 | ||
![]() |
2c8e54bb1e | ||
![]() |
932d1e2687 |
@ -3,13 +3,14 @@ import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tinydb import TinyDB
|
||||
from tinydb import TinyDB, Query
|
||||
from tinydb.storages import MemoryStorage
|
||||
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.dataset import dataset, dataset_item, Q_A
|
||||
|
||||
def get_all_dataset(workdir: str) -> List[dataset]:
|
||||
def get_all_dataset(workdir: str) -> TinyDB:
|
||||
"""
|
||||
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
|
||||
|
||||
@ -17,25 +18,25 @@ def get_all_dataset(workdir: str) -> List[dataset]:
|
||||
workdir (str): 工作目录路径
|
||||
|
||||
Returns:
|
||||
List[dataset]: 包含所有数据集对象的列表
|
||||
TinyDB: 包含所有数据集对象的TinyDB对象
|
||||
"""
|
||||
dataset_dir = os.path.join(workdir, "dataset")
|
||||
if not os.path.exists(dataset_dir):
|
||||
return []
|
||||
return TinyDB(storage=MemoryStorage)
|
||||
|
||||
datasets = []
|
||||
db = TinyDB(storage=MemoryStorage)
|
||||
for filename in os.listdir(dataset_dir):
|
||||
if filename.endswith(".json"):
|
||||
filepath = os.path.join(dataset_dir, filename)
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
datasets.append(dataset(**data))
|
||||
db.insert(data)
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
print(f"Error loading dataset file {filename}: {str(e)}")
|
||||
continue
|
||||
|
||||
return datasets
|
||||
return db
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -45,5 +46,5 @@ if __name__ == "__main__":
|
||||
datasets = get_all_dataset(workdir)
|
||||
# 打印结果
|
||||
print(f"Found {len(datasets)} datasets:")
|
||||
for ds in datasets:
|
||||
print(f"- {ds.name} (ID: {ds.id})")
|
||||
for ds in datasets.all():
|
||||
print(f"- {ds['name']} (ID: {ds['id']})")
|
@ -1,9 +1,54 @@
|
||||
import gradio as gr
|
||||
from global_var import datasets
|
||||
from tinydb import Query
|
||||
|
||||
def dataset_manage_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 数据集管理")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
pass
|
||||
# 获取数据集列表并设置初始值
|
||||
datasets_list = [str(ds["name"]) for ds in datasets.all()]
|
||||
initial_dataset = datasets_list[0] if datasets_list else None
|
||||
|
||||
dataset_dropdown = gr.Dropdown(
|
||||
choices=datasets_list,
|
||||
value=initial_dataset, # 设置初始选中项
|
||||
label="选择数据集",
|
||||
allow_custom_value=True,
|
||||
interactive=True
|
||||
)
|
||||
|
||||
# 添加数据集展示组件
|
||||
qa_dataset = gr.Dataset(
|
||||
components=["text", "text"],
|
||||
label="问答数据",
|
||||
headers=["问题", "答案"],
|
||||
samples=[["示例问题", "示例答案"]]
|
||||
)
|
||||
|
||||
def update_qa_display(dataset_name):
|
||||
if not dataset_name:
|
||||
return {"samples": [], "__type__": "update"}
|
||||
|
||||
# 从数据库获取数据集
|
||||
Dataset = Query()
|
||||
ds = datasets.get(Dataset.name == dataset_name)
|
||||
if not ds:
|
||||
return {"samples": [], "__type__": "update"}
|
||||
|
||||
# 提取所有Q_A数据
|
||||
qa_list = []
|
||||
for item in ds["dataset_items"]:
|
||||
for qa in item["message"]:
|
||||
qa_list.append([qa["question"], qa["answer"]])
|
||||
|
||||
return {"samples": qa_list, "__type__": "update"}
|
||||
|
||||
# 绑定事件,更新QA数据显示
|
||||
dataset_dropdown.change(
|
||||
update_qa_display,
|
||||
inputs=dataset_dropdown,
|
||||
outputs=qa_dataset
|
||||
)
|
||||
|
||||
return demo
|
@ -1,6 +1,7 @@
|
||||
from db import get_sqlite_engine,get_prompt_tinydb
|
||||
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
||||
from tools import scan_docs_directory
|
||||
|
||||
prompt_store = get_prompt_tinydb("workdir")
|
||||
sql_engine = get_sqlite_engine("workdir")
|
||||
docs = scan_docs_directory("workdir")
|
||||
docs = scan_docs_directory("workdir")
|
||||
datasets = get_all_dataset("workdir")
|
@ -19,7 +19,7 @@ class dataset_item(BaseModel):
|
||||
|
||||
class dataset(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="数据集ID")
|
||||
name: Optional[str] = Field(default=None, description="数据集名称")
|
||||
name: str = Field(default="", description="数据集名称")
|
||||
model_id: Optional[list[str]] = Field(default=None, description="数据集使用的模型ID")
|
||||
source_doc: Optional[doc] = Field(default=None, description="数据集来源文档")
|
||||
description: Optional[str] = Field(default="", description="数据集描述")
|
||||
|
Loading…
x
Reference in New Issue
Block a user