Compare commits

6 Commits

Author SHA1 Message Date
carry
b1e98ca913 feat(db): 初始化数据库并创建 APIProvider 表
- 新增 init_db.py 文件,实现数据库初始化和 APIProvider 表的创建
- 新增 dataset_generation.py 文件,定义 LLMRequest、LLMResponse 和 APIProvider 模型
- 在初始化数据库时,如果环境变量中存在 API_KEY、BASE_URL 和 MODEL_ID,会自动添加一条 APIProvider 记录
2025-04-06 19:59:23 +08:00
carry
2d5a5277ae refactor(schema): 更新 prompt 导入
- 将 prompt_templeta 重命名为 promptTempleta,以符合驼峰命名规范
- 优化导入语句格式
2025-04-06 19:39:43 +08:00
carry
519a5f3773 feat(frontend): 添加前端页面模块并实现基本布局
- 新增 chat_page.py、setting_page.py 和 train_page.py 文件,分别实现聊天、设置和微调页面的基本布局
- 添加 main.py 文件,集成所有页面并创建主应用
- 在 requirements.txt 中添加 gradio 依赖
2025-04-06 14:49:01 +08:00
carry
1f4d491694 build: 添加 pydantic 依赖 2025-04-05 01:00:33 +08:00
carry
8ce4f1e373 chore: 添加 .roo 到 .gitignore 文件
- 在 .gitignore 文件中添加 .roo 目录,以忽略相关文件
2025-04-05 00:59:42 +08:00
carry
3395b860e4 refactor(parse_markdown): 重构 Markdown 解析逻辑并使用 Pydantic 模型
将 MarkdownNode 类重构为使用 Pydantic 模型,提高代码的可维护性和类型安全性。同时,将解析逻辑与节点操作分离,简化代码结构。
2025-04-04 20:50:39 +08:00
11 changed files with 195 additions and 26 deletions

1
.gitignore vendored
View File

@@ -11,6 +11,7 @@ env/
# IDE
.vscode/
.idea/
.roo
# Environment files
.env

52
db/init_db.py Normal file
View File

@@ -0,0 +1,52 @@
from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select
from typing import Optional
import os
from pathlib import Path
import sys
from dotenv import load_dotenv
from sqlalchemy.engine import Engine
# 项目路径配置
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset_generation import APIProvider
# 全局引擎实例(可选)
_engine: Optional[Engine] = None
def get_engine(workdir: str) -> Engine:
    """Return the module-wide SQLAlchemy engine, creating it on first use.

    On the first call the directory ``<workdir>/db`` is created if missing and
    an engine for the SQLite file ``<workdir>/db/db.sqlite`` is built and
    cached in the module-level ``_engine``.

    Once an engine is cached, later calls return it as-is; their ``workdir``
    argument is not consulted.

    Args:
        workdir: Directory under which the ``db`` folder lives.

    Returns:
        The cached engine instance.
    """
    global _engine
    if _engine:
        return _engine

    sqlite_dir = os.path.join(workdir, "db")
    os.makedirs(sqlite_dir, exist_ok=True)
    sqlite_file = os.path.join(sqlite_dir, "db.sqlite")
    _engine = create_engine(f"sqlite:///{sqlite_file}")
    return _engine
def initialize_db(engine: Engine) -> None:
    """Create all registered tables and optionally seed one APIProvider row.

    After creating the tables, ``.env`` is loaded and the environment
    variables ``API_KEY``, ``BASE_URL`` and ``MODEL_ID`` are read. A provider
    row is inserted only when all three are set (non-empty) and the
    ``APIProvider`` table currently holds no rows.

    Args:
        engine: Engine bound to the target database.
    """
    SQLModel.metadata.create_all(engine)

    load_dotenv()
    api_key = os.getenv("API_KEY")
    base_url = os.getenv("BASE_URL")
    model_id = os.getenv("MODEL_ID")
    if not (api_key and base_url and model_id):
        # Incomplete configuration: nothing to seed.
        return

    with Session(engine) as session:
        # Query with the select() syntax; a single row is enough to know
        # whether any provider already exists.
        already_present = session.exec(select(APIProvider).limit(1)).first()
        if already_present:
            return
        session.add(
            APIProvider(base_url=base_url, model_id=model_id, api_key=api_key)
        )
        session.commit()
if __name__ == "__main__":
    # When run as a script, bootstrap the database under the "workdir"
    # folder that sits next to this file's parent directory.
    here = os.path.dirname(__file__)
    workdir = os.path.join(here, "..", "workdir")
    initialize_db(get_engine(workdir))

9
frontend/chat_page.py Normal file
View File

@@ -0,0 +1,9 @@
import gradio as gr
def chat_page():
    """Build the chat page as its own Gradio ``Blocks`` and return it.

    The page currently holds only a heading and an empty row/column
    placeholder.
    """
    with gr.Blocks() as page:
        gr.Markdown("## 聊天")
        # Placeholder layout: one row containing one empty column.
        with gr.Row(), gr.Column():
            pass
    return page

9
frontend/setting_page.py Normal file
View File

@@ -0,0 +1,9 @@
import gradio as gr
def setting_page():
    """Build the settings page as its own Gradio ``Blocks`` and return it.

    The page currently holds only a heading and an empty row/column
    placeholder.
    """
    with gr.Blocks() as page:
        gr.Markdown("## 设置")
        # Placeholder layout: one row containing one empty column.
        with gr.Row(), gr.Column():
            pass
    return page

9
frontend/train_page.py Normal file
View File

@@ -0,0 +1,9 @@
import gradio as gr
def train_page():
    """Build the fine-tuning page as its own Gradio ``Blocks`` and return it.

    The page currently holds only a heading and an empty row/column
    placeholder.
    """
    with gr.Blocks() as page:
        gr.Markdown("## 微调")
        # Placeholder layout: one row containing one empty column.
        with gr.Row(), gr.Column():
            pass
    return page

23
main.py Normal file
View File

@@ -0,0 +1,23 @@
import gradio as gr
from frontend.setting_page import setting_page
from frontend.chat_page import chat_page
from frontend.train_page import train_page
def main():
    """Assemble the three pages into one tabbed Gradio app and launch it."""
    # Pages are built in the order settings, chat, training.
    settings_tab = setting_page()
    chat_tab = chat_page()
    training_tab = train_page()

    # Tab label -> page, in on-screen order.
    tabs = (
        ("微调", training_tab),
        ("聊天", chat_tab),
        ("设置", settings_tab),
    )

    with gr.Blocks() as app:
        gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
        with gr.Tabs():
            for label, page in tabs:
                with gr.TabItem(label):
                    page.render()

    app.launch()


if __name__ == "__main__":
    main()

View File

@@ -1,2 +1,4 @@
openai>=1.0.0
python-dotenv>=1.0.0
pydantic>=2.0.0
gradio>=3.0.0

4
schema/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .dataset import *
from .dataset_generation import APIProvider, LLMResponse, LLMRequest
from .md_doc import MarkdownNode
from .prompt import promptTempleta

View File

@@ -0,0 +1,51 @@
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import SQLModel, Relationship, Field
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class APIProvider(SQLModel, table=True):
    """Table model holding the connection settings of one LLM API provider."""

    # Primary key; defaults to None when the object is constructed.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Required connection settings.
    base_url: str = Field(..., description="API的基础URL")
    model_id: str = Field(..., description="API使用的模型ID")
    # Optional credential.
    api_key: Optional[str] = Field(
        default=None,
        description="用于身份验证的API密钥",
    )
    # Filled with the current UTC time when no value is supplied.
    created_at: datetime = Field(
        default_factory=_utc_now,
        description="记录创建时间",
    )
def _blank_token_usage() -> dict:
    """Return a fresh token-usage dict: counters at 0, cache stats unset."""
    return {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "prompt_cache_hit_tokens": None,
        "prompt_cache_miss_tokens": None,
    }


def _unset_llm_parameters() -> dict:
    """Return a fresh generation-parameter dict with every value None."""
    return dict.fromkeys(
        (
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "seed",
        ),
        None,
    )


class LLMResponse(SQLModel):
    """Plain (non-table) model describing one response returned by an LLM API."""

    # Defaults to the current UTC time.
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="响应的时间戳",
    )
    response_id: str = Field(..., description="响应的唯一ID")
    # Each instance gets its own dict from the factory.
    tokens_usage: dict = Field(
        default_factory=_blank_token_usage,
        description="token使用信息",
    )
    response_content: dict = Field(
        default_factory=dict,
        description="API响应的内容",
    )
    total_duration: float = Field(
        default=0.0,
        description="请求的总时长,单位为秒",
    )
    llm_parameters: dict = Field(
        default_factory=_unset_llm_parameters,
        description="API的生成参数",
    )
def _initial_total_usage() -> dict:
    """Return a fresh aggregate token-usage dict: counters 0, cache stats unset."""
    return {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "prompt_cache_hit_tokens": None,
        "prompt_cache_miss_tokens": None,
    }


class LLMRequest(SQLModel):
    """Plain (non-table) model describing one prompt sent to a provider.

    Holds the prompt, the provider it targets, the list of responses received
    and aggregate duration / token-usage figures.
    """

    prompt: str = Field(..., description="发送给API的提示词")
    # References the "apiprovider" table's id column.
    provider_id: int = Field(foreign_key="apiprovider.id")
    # NOTE(review): this class is not declared with table=True; confirm that
    # Relationship() behaves as intended on a non-table model.
    provider: APIProvider = Relationship()
    format: Optional[str] = Field(default=None, description="API响应的格式")
    response: list[LLMResponse] = Field(
        default_factory=list,
        description="API响应列表",
    )
    error: Optional[list[str]] = Field(
        default=None,
        description="API请求过程中发生的错误信息",
    )
    total_duration: float = Field(
        default=0.0,
        description="请求的总时长,单位为秒",
    )
    total_tokens_usage: dict = Field(
        default_factory=_initial_total_usage,
        description="token使用信息",
    )

13
schema/md_doc.py Normal file
View File

@@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import List, Optional
class MarkdownNode(BaseModel):
    """One node of a parsed Markdown heading tree.

    Attributes are declared as Pydantic fields: the heading ``level``
    (default 0), the heading ``title`` (default "Root"), the optional text
    ``content`` attached to the heading, and the list of ``children`` nodes.
    """

    # Pydantic v2 configuration. The nested ``class Config`` form is
    # deprecated in v2 (this module already relies on v2 via
    # ``model_rebuild``); a plain dict is accepted for ``model_config`` and
    # keeps the same ``arbitrary_types_allowed`` setting without a new import.
    model_config = {"arbitrary_types_allowed": True}

    level: int = Field(default=0, description="节点层级")
    title: str = Field(default="Root", description="节点标题")
    content: Optional[str] = Field(default=None, description="节点内容")
    # Self-referential: resolved by the model_rebuild() call below.
    children: List['MarkdownNode'] = Field(
        default_factory=list,
        description="子节点列表",
    )


# Resolve the 'MarkdownNode' forward reference used by ``children``.
MarkdownNode.model_rebuild()

View File

@@ -1,28 +1,24 @@
import re
import sys
from pathlib import Path
class MarkdownNode:
def __init__(self, level=0, title="Root"):
self.level = level
self.title = title
self.content = "" # 使用字符串存储合并后的内容
self.children = []
# 添加项目根目录到sys.path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema import MarkdownNode
def __repr__(self):
return f"({self.level}) {self.title}"
def add_child(parent, child):
parent.children.append(child)
def add_child(self, child):
self.children.append(child)
def print_tree(self, indent=0):
def print_tree(node, indent=0):
prefix = "" * (indent - 1) + "└─ " if indent > 0 else ""
print(f"{prefix}{self.title}")
if self.content:
print(f"{prefix}{node.title}")
if node.content:
content_prefix = "" * indent + "├─ [内容]"
print(content_prefix)
for line in self.content.split('\n'):
for line in node.content.split('\n'):
print("" * indent + "" + line)
for child in self.children:
child.print_tree(indent + 1)
for child in node.children:
print_tree(child, indent + 1)
def parse_markdown(markdown):
lines = markdown.split('\n')
@@ -51,10 +47,10 @@ def parse_markdown(markdown):
if match:
level = len(match.group(1))
title = match.group(2)
node = MarkdownNode(level, title)
node = MarkdownNode(level=level, title=title, content="", children=[])
while stack[-1].level >= level:
stack.pop()
stack[-1].add_child(node)
add_child(stack[-1], node)
stack.append(node)
else:
if stack[-1].content:
@@ -65,9 +61,9 @@ def parse_markdown(markdown):
if __name__=="__main__":
# 从文件读取 Markdown 内容
with open("example.md", "r", encoding="utf-8") as f:
with open("workdir/example.md", "r", encoding="utf-8") as f:
markdown = f.read()
# 解析 Markdown 并打印树结构
root = parse_markdown(markdown)
root.print_tree()
print_tree(root)