Compare commits
11 Commits
mvp
...
841e14a093
Author | SHA1 | Date | |
---|---|---|---|
![]() |
841e14a093 | ||
![]() |
2ff077bb1c | ||
![]() |
513b639bce | ||
![]() |
f93f213a31 | ||
![]() |
10b4c29bda | ||
![]() |
b1e98ca913 | ||
![]() |
2d5a5277ae | ||
![]() |
519a5f3773 | ||
![]() |
1f4d491694 | ||
![]() |
8ce4f1e373 | ||
![]() |
3395b860e4 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -11,6 +11,7 @@ env/
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
.roo
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
|
3
db/__init__.py
Normal file
3
db/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .init_db import get_engine, initialize_db
|
||||
|
||||
__all__ = ['get_engine', 'initialize_db']
|
79
db/init_db.py
Normal file
79
db/init_db.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from sqlmodel import SQLModel, create_engine, Session
|
||||
from sqlmodel import select
|
||||
from typing import Optional
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
# 将项目根目录添加到系统路径中,以便能够导入项目中的其他模块
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema.dataset_generation import APIProvider
|
||||
|
||||
# 全局变量,用于存储数据库引擎实例
|
||||
_engine: Optional[Engine] = None
|
||||
|
||||
def get_engine(workdir: str) -> Engine:
|
||||
"""
|
||||
获取数据库引擎实例。如果引擎尚未创建,则创建一个新的引擎并返回。
|
||||
|
||||
Args:
|
||||
workdir (str): 工作目录路径,用于确定数据库文件的存储位置。
|
||||
|
||||
Returns:
|
||||
Engine: SQLAlchemy 数据库引擎实例。
|
||||
"""
|
||||
global _engine
|
||||
if not _engine:
|
||||
# 创建数据库目录(如果不存在)
|
||||
db_dir = os.path.join(workdir, "db")
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
# 定义数据库文件路径
|
||||
db_path = os.path.join(db_dir, "db.sqlite")
|
||||
# 创建数据库URL
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
# 创建数据库引擎
|
||||
_engine = create_engine(db_url)
|
||||
return _engine
|
||||
|
||||
def initialize_db(engine: Engine) -> None:
|
||||
"""
|
||||
初始化数据库,创建所有表结构,并插入初始数据(如果不存在)。
|
||||
|
||||
Args:
|
||||
engine (Engine): SQLAlchemy 数据库引擎实例。
|
||||
"""
|
||||
# 创建所有表结构
|
||||
SQLModel.metadata.create_all(engine)
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
# 从环境变量中获取API相关配置
|
||||
api_key = os.getenv("API_KEY")
|
||||
base_url = os.getenv("BASE_URL")
|
||||
model_id = os.getenv("MODEL_ID")
|
||||
|
||||
# 如果所有必要的环境变量都存在,则插入初始数据
|
||||
if api_key and base_url and model_id:
|
||||
with Session(engine) as session:
|
||||
# 查询是否已存在APIProvider记录
|
||||
statement = select(APIProvider).limit(1)
|
||||
existing_provider = session.exec(statement).first()
|
||||
|
||||
# 如果不存在,则插入新的APIProvider记录
|
||||
if not existing_provider:
|
||||
provider = APIProvider(
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
api_key=api_key
|
||||
)
|
||||
session.add(provider)
|
||||
session.commit()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 定义工作目录路径
|
||||
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
|
||||
# 获取数据库引擎
|
||||
engine = get_engine(workdir)
|
||||
# 初始化数据库
|
||||
initialize_db(engine)
|
4
frontend/__init__.py
Normal file
4
frontend/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .chat_page import chat_page
|
||||
from .setting_page import setting_page
|
||||
from .train_page import train_page
|
||||
from .dataset_page import dataset_page
|
9
frontend/chat_page.py
Normal file
9
frontend/chat_page.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import gradio as gr
|
||||
|
||||
def chat_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 聊天")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
pass
|
||||
return demo
|
9
frontend/dataset_page.py
Normal file
9
frontend/dataset_page.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import gradio as gr
|
||||
|
||||
def dataset_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 数据集")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
pass
|
||||
return demo
|
35
frontend/setting_page.py
Normal file
35
frontend/setting_page.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import gradio as gr
|
||||
from typing import List, Dict
|
||||
from sqlmodel import Session, select
|
||||
from db import get_engine
|
||||
from schema import APIProvider
|
||||
import os
|
||||
|
||||
# 获取数据库引擎
|
||||
engine = get_engine(os.path.join(os.path.dirname(__file__), "..", "workdir"))
|
||||
|
||||
def setting_page():
|
||||
def get_providers() -> List[List[str]]:
|
||||
with Session(engine) as session:
|
||||
providers = session.exec(select(APIProvider)).all()
|
||||
return [
|
||||
[p.id, p.model_id, p.base_url, p.api_key or ""]
|
||||
for p in providers
|
||||
]
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## API Provider 管理")
|
||||
|
||||
with gr.Row():
|
||||
# API Provider列表
|
||||
with gr.Column(scale=2):
|
||||
provider_table = gr.DataFrame(
|
||||
headers=["id" , "model id", "URL", "API Key"],
|
||||
datatype=["number","str", "str", "str"],
|
||||
interactive=True,
|
||||
value=get_providers(),
|
||||
wrap=True,
|
||||
col_count=(4, "fixed")
|
||||
)
|
||||
|
||||
return demo
|
9
frontend/train_page.py
Normal file
9
frontend/train_page.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import gradio as gr
|
||||
|
||||
def train_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 微调")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
pass
|
||||
return demo
|
27
main.py
Normal file
27
main.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import gradio as gr
|
||||
from frontend.setting_page import setting_page
|
||||
from frontend import chat_page,setting_page,train_page,dataset_page
|
||||
from db import initialize_db as init_db,get_engine
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_db(get_engine("workdir"))
|
||||
|
||||
pages = []
|
||||
pages.append(setting_page())
|
||||
pages.append(chat_page())
|
||||
pages.append(train_page())
|
||||
pages.append(dataset_page())
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("聊天"):
|
||||
chat_page()
|
||||
with gr.TabItem("微调"):
|
||||
train_page()
|
||||
with gr.TabItem("数据集"):
|
||||
dataset_page()
|
||||
with gr.TabItem("设置"):
|
||||
setting_page()
|
||||
|
||||
app.launch()
|
@@ -1,2 +1,4 @@
|
||||
openai>=1.0.0
|
||||
python-dotenv>=1.0.0
|
||||
python-dotenv>=1.0.0
|
||||
pydantic>=2.0.0
|
||||
gradio>=3.0.0
|
4
schema/__init__.py
Normal file
4
schema/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .dataset import *
|
||||
from .dataset_generation import APIProvider, LLMResponse, LLMRequest
|
||||
from .md_doc import MarkdownNode
|
||||
from .prompt import promptTempleta
|
51
schema/dataset_generation.py
Normal file
51
schema/dataset_generation.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Relationship, Field
|
||||
|
||||
class APIProvider(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
base_url: str = Field(..., description="API的基础URL")
|
||||
model_id: str = Field(..., description="API使用的模型ID")
|
||||
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
|
||||
created_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
description="记录创建时间"
|
||||
)
|
||||
|
||||
class LLMResponse(SQLModel):
|
||||
timestamp: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
description="响应的时间戳"
|
||||
)
|
||||
response_id: str = Field(..., description="响应的唯一ID")
|
||||
tokens_usage: dict = Field(default_factory=lambda: {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"prompt_cache_hit_tokens": None,
|
||||
"prompt_cache_miss_tokens": None
|
||||
}, description="token使用信息")
|
||||
response_content: dict = Field(default_factory=dict, description="API响应的内容")
|
||||
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
|
||||
llm_parameters: dict = Field(default_factory=lambda: {
|
||||
"temperature": None,
|
||||
"max_tokens": None,
|
||||
"top_p": None,
|
||||
"frequency_penalty": None,
|
||||
"presence_penalty": None,
|
||||
"seed": None
|
||||
}, description="API的生成参数")
|
||||
|
||||
class LLMRequest(SQLModel):
|
||||
prompt: str = Field(..., description="发送给API的提示词")
|
||||
provider_id: int = Field(foreign_key="apiprovider.id")
|
||||
provider: APIProvider = Relationship()
|
||||
format: Optional[str] = Field(default=None, description="API响应的格式")
|
||||
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
|
||||
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
|
||||
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
|
||||
total_tokens_usage: dict = Field(default_factory=lambda: {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"prompt_cache_hit_tokens": None,
|
||||
"prompt_cache_miss_tokens": None
|
||||
}, description="token使用信息")
|
13
schema/md_doc.py
Normal file
13
schema/md_doc.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
class MarkdownNode(BaseModel):
|
||||
level: int = Field(default=0, description="节点层级")
|
||||
title: str = Field(default="Root", description="节点标题")
|
||||
content: Optional[str] = Field(default=None, description="节点内容")
|
||||
children: List['MarkdownNode'] = Field(default_factory=list, description="子节点列表")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
MarkdownNode.model_rebuild()
|
@@ -1,28 +1,24 @@
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
class MarkdownNode:
|
||||
def __init__(self, level=0, title="Root"):
|
||||
self.level = level
|
||||
self.title = title
|
||||
self.content = "" # 使用字符串存储合并后的内容
|
||||
self.children = []
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import MarkdownNode
|
||||
|
||||
def __repr__(self):
|
||||
return f"({self.level}) {self.title}"
|
||||
def add_child(parent, child):
|
||||
parent.children.append(child)
|
||||
|
||||
def add_child(self, child):
|
||||
self.children.append(child)
|
||||
|
||||
def print_tree(self, indent=0):
|
||||
prefix = "│ " * (indent - 1) + "└─ " if indent > 0 else ""
|
||||
print(f"{prefix}{self.title}")
|
||||
if self.content:
|
||||
content_prefix = "│ " * indent + "├─ [内容]"
|
||||
print(content_prefix)
|
||||
for line in self.content.split('\n'):
|
||||
print("│ " * indent + "│ " + line)
|
||||
for child in self.children:
|
||||
child.print_tree(indent + 1)
|
||||
def print_tree(node, indent=0):
|
||||
prefix = "│ " * (indent - 1) + "└─ " if indent > 0 else ""
|
||||
print(f"{prefix}{node.title}")
|
||||
if node.content:
|
||||
content_prefix = "│ " * indent + "├─ [内容]"
|
||||
print(content_prefix)
|
||||
for line in node.content.split('\n'):
|
||||
print("│ " * indent + "│ " + line)
|
||||
for child in node.children:
|
||||
print_tree(child, indent + 1)
|
||||
|
||||
def parse_markdown(markdown):
|
||||
lines = markdown.split('\n')
|
||||
@@ -51,10 +47,10 @@ def parse_markdown(markdown):
|
||||
if match:
|
||||
level = len(match.group(1))
|
||||
title = match.group(2)
|
||||
node = MarkdownNode(level, title)
|
||||
node = MarkdownNode(level=level, title=title, content="", children=[])
|
||||
while stack[-1].level >= level:
|
||||
stack.pop()
|
||||
stack[-1].add_child(node)
|
||||
add_child(stack[-1], node)
|
||||
stack.append(node)
|
||||
else:
|
||||
if stack[-1].content:
|
||||
@@ -65,9 +61,9 @@ def parse_markdown(markdown):
|
||||
|
||||
if __name__=="__main__":
|
||||
# 从文件读取 Markdown 内容
|
||||
with open("example.md", "r", encoding="utf-8") as f:
|
||||
with open("workdir/example.md", "r", encoding="utf-8") as f:
|
||||
markdown = f.read()
|
||||
|
||||
# 解析 Markdown 并打印树结构
|
||||
root = parse_markdown(markdown)
|
||||
root.print_tree()
|
||||
print_tree(root)
|
||||
|
Reference in New Issue
Block a user