feat(dataset): 初步完成数据集管理页面和功能
This commit is contained in:
parent
932d1e2687
commit
2c8e54bb1e
@ -3,13 +3,14 @@ import sys
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
from tinydb import TinyDB
|
from tinydb import TinyDB, Query
|
||||||
|
from tinydb.storages import MemoryStorage
|
||||||
|
|
||||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from schema.dataset import dataset, dataset_item, Q_A
|
from schema.dataset import dataset, dataset_item, Q_A
|
||||||
|
|
||||||
def get_all_dataset(workdir: str) -> List[dataset]:
|
def get_all_dataset(workdir: str) -> TinyDB:
|
||||||
"""
|
"""
|
||||||
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
|
扫描workdir/dataset目录下的所有json文件并读取为dataset对象列表
|
||||||
|
|
||||||
@ -17,25 +18,25 @@ def get_all_dataset(workdir: str) -> List[dataset]:
|
|||||||
workdir (str): 工作目录路径
|
workdir (str): 工作目录路径
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[dataset]: 包含所有数据集对象的列表
|
TinyDB: 包含所有数据集对象的TinyDB对象
|
||||||
"""
|
"""
|
||||||
dataset_dir = os.path.join(workdir, "dataset")
|
dataset_dir = os.path.join(workdir, "dataset")
|
||||||
if not os.path.exists(dataset_dir):
|
if not os.path.exists(dataset_dir):
|
||||||
return []
|
return TinyDB(storage=MemoryStorage)
|
||||||
|
|
||||||
datasets = []
|
db = TinyDB(storage=MemoryStorage)
|
||||||
for filename in os.listdir(dataset_dir):
|
for filename in os.listdir(dataset_dir):
|
||||||
if filename.endswith(".json"):
|
if filename.endswith(".json"):
|
||||||
filepath = os.path.join(dataset_dir, filename)
|
filepath = os.path.join(dataset_dir, filename)
|
||||||
try:
|
try:
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
datasets.append(dataset(**data))
|
db.insert(data)
|
||||||
except (json.JSONDecodeError, Exception) as e:
|
except (json.JSONDecodeError, Exception) as e:
|
||||||
print(f"Error loading dataset file {filename}: {str(e)}")
|
print(f"Error loading dataset file {filename}: {str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return datasets
|
return db
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -45,5 +46,5 @@ if __name__ == "__main__":
|
|||||||
datasets = get_all_dataset(workdir)
|
datasets = get_all_dataset(workdir)
|
||||||
# 打印结果
|
# 打印结果
|
||||||
print(f"Found {len(datasets)} datasets:")
|
print(f"Found {len(datasets)} datasets:")
|
||||||
for ds in datasets:
|
for ds in datasets.all():
|
||||||
print(f"- {ds.name} (ID: {ds.id})")
|
print(f"- {ds['name']} (ID: {ds['id']})")
|
@ -1,9 +1,28 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from global_var import datasets
|
||||||
|
|
||||||
def dataset_manage_page():
|
def dataset_manage_page():
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
gr.Markdown("## 数据集管理")
|
gr.Markdown("## 数据集管理")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
# 获取数据集列表并设置初始值
|
||||||
|
datasets_list = [str(ds["name"]) for ds in datasets.all()]
|
||||||
|
initial_dataset = datasets_list[0] if datasets_list else None
|
||||||
|
|
||||||
|
dataset_dropdown = gr.Dropdown(
|
||||||
|
choices=datasets_list,
|
||||||
|
value=initial_dataset, # 设置初始选中项
|
||||||
|
label="选择数据集",
|
||||||
|
allow_custom_value=True,
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
dataset_state = gr.State(value=initial_dataset) # 用数据集初始值初始化状态
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# 绑定事件,确保交互时更新状态
|
||||||
|
dataset_dropdown.change(lambda x: x, inputs=dataset_dropdown, outputs=dataset_state)
|
||||||
|
|
||||||
return demo
|
return demo
|
@ -1,6 +1,7 @@
|
|||||||
from db import get_sqlite_engine,get_prompt_tinydb
|
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
||||||
from tools import scan_docs_directory
|
from tools import scan_docs_directory
|
||||||
|
|
||||||
prompt_store = get_prompt_tinydb("workdir")
|
prompt_store = get_prompt_tinydb("workdir")
|
||||||
sql_engine = get_sqlite_engine("workdir")
|
sql_engine = get_sqlite_engine("workdir")
|
||||||
docs = scan_docs_directory("workdir")
|
docs = scan_docs_directory("workdir")
|
||||||
|
datasets = get_all_dataset("workdir")
|
Loading…
x
Reference in New Issue
Block a user