Compare commits
4 Commits
0a6ae7a4ee
...
67281fe06a
Author | SHA1 | Date | |
---|---|---|---|
![]() |
67281fe06a | ||
![]() |
2d905a0270 | ||
![]() |
374b124cf8 | ||
![]() |
74ae5e1426 |
@ -1,3 +1,9 @@
|
|||||||
from .init_db import get_engine, initialize_sqlite_db
|
from .init_db import get_sqlite_engine, initialize_sqlite_db
|
||||||
|
from .prompt_store import get_prompt_tinydb, initialize_prompt_store
|
||||||
|
|
||||||
__all__ = ['get_engine', 'initialize_sqlite_db']
|
__all__ = [
|
||||||
|
"get_sqlite_engine",
|
||||||
|
"initialize_sqlite_db",
|
||||||
|
"get_prompt_tinydb",
|
||||||
|
"initialize_prompt_store"
|
||||||
|
]
|
@ -1,9 +1,9 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
from sqlmodel import SQLModel, create_engine, Session
|
from sqlmodel import SQLModel, create_engine, Session
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ from schema.dataset_generation import APIProvider
|
|||||||
# 全局变量,用于存储数据库引擎实例
|
# 全局变量,用于存储数据库引擎实例
|
||||||
_engine: Optional[Engine] = None
|
_engine: Optional[Engine] = None
|
||||||
|
|
||||||
def get_engine(workdir: str) -> Engine:
|
def get_sqlite_engine(workdir: str) -> Engine:
|
||||||
"""
|
"""
|
||||||
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
||||||
|
|
||||||
@ -74,6 +74,6 @@ if __name__ == "__main__":
|
|||||||
# 定义工作目录路径
|
# 定义工作目录路径
|
||||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||||
# 获取数据库引擎
|
# 获取数据库引擎
|
||||||
engine = get_engine(workdir)
|
engine = get_sqlite_engine(workdir)
|
||||||
# 初始化数据库
|
# 初始化数据库
|
||||||
initialize_sqlite_db(engine)
|
initialize_sqlite_db(engine)
|
60
db/prompt_store.py
Normal file
60
db/prompt_store.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from tinydb import TinyDB, Query
|
||||||
|
from tinydb.storages import JSONStorage
|
||||||
|
|
||||||
|
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||||
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
|
from schema.prompt import promptTempleta
|
||||||
|
|
||||||
|
# 全局变量,用于存储TinyDB实例
|
||||||
|
_db_instance: Optional[TinyDB] = None
|
||||||
|
# 自定义存储类,用于格式化JSON数据
|
||||||
|
def get_prompt_tinydb(workdir: str) -> TinyDB:
|
||||||
|
"""
|
||||||
|
获取TinyDB实例。如果实例尚未创建,则创建一个新的并返回。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TinyDB: TinyDB数据库实例
|
||||||
|
"""
|
||||||
|
global _db_instance
|
||||||
|
if not _db_instance:
|
||||||
|
# 创建数据库目录(如果不存在)
|
||||||
|
db_dir = os.path.join(workdir, "db")
|
||||||
|
os.makedirs(db_dir, exist_ok=True)
|
||||||
|
# 定义数据库文件路径
|
||||||
|
db_path = os.path.join(db_dir, "prompts.json")
|
||||||
|
# 创建TinyDB实例
|
||||||
|
_db_instance = TinyDB(db_path)
|
||||||
|
return _db_instance
|
||||||
|
|
||||||
|
def initialize_prompt_store(db: TinyDB) -> None:
|
||||||
|
"""
|
||||||
|
初始化prompt模板存储
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db (TinyDB): TinyDB数据库实例
|
||||||
|
"""
|
||||||
|
db.insert(promptTempleta(name="default",
|
||||||
|
description="默认提示词模板",
|
||||||
|
content="""项目名为:{ project_name }
|
||||||
|
请依据以下该项目官方文档的部分内容,创造合适的对话数据集用于微调一个了解该项目的小模型的语料,要求兼顾文档中间尽可能多的信息点,使用中文
|
||||||
|
文档节选:{ content }
|
||||||
|
按照如下json格式返回:{ json }""").model_dump())
|
||||||
|
# TinyDB不需要显式创建表结构,首次使用时自动创建
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 定义工作目录路径
|
||||||
|
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||||
|
# 获取数据库实例
|
||||||
|
db = get_prompt_tinydb(workdir)
|
||||||
|
# 初始化prompt存储
|
||||||
|
initialize_prompt_store(db)
|
@ -1,7 +1,7 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from typing import List
|
from typing import List
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from db import get_engine
|
from db import get_sqlite_engine
|
||||||
from schema import APIProvider
|
from schema import APIProvider
|
||||||
import os
|
import os
|
||||||
from global_var import sql_engine
|
from global_var import sql_engine
|
||||||
@ -30,7 +30,7 @@ def setting_page():
|
|||||||
session.add(new_provider)
|
session.add(new_provider)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(new_provider)
|
session.refresh(new_provider)
|
||||||
return get_providers()
|
return get_providers(), "", "", "" # 返回清空后的输入框值
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise gr.Error(f"添加失败: {str(e)}")
|
raise gr.Error(f"添加失败: {str(e)}")
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ def setting_page():
|
|||||||
add_button.click(
|
add_button.click(
|
||||||
fn=add_provider,
|
fn=add_provider,
|
||||||
inputs=[model_id_input, base_url_input, api_key_input],
|
inputs=[model_id_input, base_url_input, api_key_input],
|
||||||
outputs=[provider_table]
|
outputs=[provider_table, model_id_input, base_url_input, api_key_input] # 添加清空输入框的输出
|
||||||
)
|
)
|
||||||
|
|
||||||
provider_table.select(select_record, [], [], show_progress="hidden")
|
provider_table.select(select_record, [], [], show_progress="hidden")
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
from db import get_engine
|
from db import get_sqlite_engine
|
||||||
|
from db import get_prompt_tinydb
|
||||||
|
|
||||||
sql_engine = get_engine("workdir")
|
prompt_store = get_prompt_tinydb("workdir")
|
||||||
|
sql_engine = get_sqlite_engine("workdir")
|
6
main.py
6
main.py
@ -1,12 +1,12 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from frontend.setting_page import setting_page
|
from frontend.setting_page import setting_page
|
||||||
from frontend import *
|
from frontend import *
|
||||||
from db import initialize_sqlite_db
|
from db import initialize_sqlite_db,initialize_prompt_store
|
||||||
from global_var import sql_engine
|
from global_var import sql_engine,prompt_store
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
initialize_sqlite_db(sql_engine)
|
initialize_sqlite_db(sql_engine)
|
||||||
|
initialize_prompt_store(prompt_store)
|
||||||
with gr.Blocks() as app:
|
with gr.Blocks() as app:
|
||||||
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
|
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
|
@ -2,4 +2,5 @@ openai>=1.0.0
|
|||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
pydantic>=2.0.0
|
pydantic>=2.0.0
|
||||||
gradio>=5.0.0
|
gradio>=5.0.0
|
||||||
langchain>=0.3
|
langchain>=0.3
|
||||||
|
tinydb>=4.0.0
|
13
schema/prompt.py
Normal file
13
schema/prompt.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
class promptTempleta(BaseModel):
|
||||||
|
id: Optional[int] = Field(default=None, description="模板ID")
|
||||||
|
name: Optional[str] = Field(default="", description="模板名称")
|
||||||
|
description: Optional[str] = Field(default="", description="模板描述")
|
||||||
|
content: str = Field(default="", min_length=1, description="模板内容")
|
||||||
|
created_at: str = Field(
|
||||||
|
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
|
||||||
|
description="记录创建时间"
|
||||||
|
)
|
Loading…
x
Reference in New Issue
Block a user