Compare commits

..

No commits in common. "b1e98ca913ffa0c9ddd65a0709c0fc0e4c0059e1" and "519a5f37736841e630218b44db0c8ebc961a6fc1" have entirely different histories.

3 changed files with 1 additions and 104 deletions

View File

@ -1,52 +0,0 @@
from sqlmodel import SQLModel, create_engine, Session
from sqlmodel import select
from typing import Optional
import os
from pathlib import Path
import sys
from dotenv import load_dotenv
from sqlalchemy.engine import Engine
# 项目路径配置
sys.path.append(str(Path(__file__).resolve().parent.parent))
from schema.dataset_generation import APIProvider
# 全局引擎实例(可选)
_engine: Optional[Engine] = None
def get_engine(workdir: str) -> Engine:
global _engine
if not _engine:
db_dir = os.path.join(workdir, "db")
os.makedirs(db_dir, exist_ok=True)
db_path = os.path.join(db_dir, "db.sqlite")
db_url = f"sqlite:///{db_path}"
_engine = create_engine(db_url)
return _engine
def initialize_db(engine: Engine) -> None:
SQLModel.metadata.create_all(engine)
load_dotenv()
api_key = os.getenv("API_KEY")
base_url = os.getenv("BASE_URL")
model_id = os.getenv("MODEL_ID")
if api_key and base_url and model_id:
with Session(engine) as session:
# 使用新的 select() 语法查询
statement = select(APIProvider).limit(1)
existing_provider = session.exec(statement).first()
if not existing_provider:
provider = APIProvider(
base_url=base_url,
model_id=model_id,
api_key=api_key
)
session.add(provider)
session.commit()
if __name__ == "__main__":
workdir = os.path.join(os.path.dirname(__file__), "..", "workdir")
engine = get_engine(workdir)
initialize_db(engine)

View File

@ -1,4 +1,4 @@
from .dataset import *
from .dataset_generation import APIProvider, LLMResponse, LLMRequest
from .md_doc import MarkdownNode
from .prompt import promptTempleta
from .prompt_templeta import prompt_templeta

View File

@ -1,51 +0,0 @@
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import SQLModel, Relationship, Field
class APIProvider(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
base_url: str = Field(..., description="API的基础URL")
model_id: str = Field(..., description="API使用的模型ID")
api_key: Optional[str] = Field(default=None, description="用于身份验证的API密钥")
created_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="记录创建时间"
)
class LLMResponse(SQLModel):
timestamp: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
description="响应的时间戳"
)
response_id: str = Field(..., description="响应的唯一ID")
tokens_usage: dict = Field(default_factory=lambda: {
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")
response_content: dict = Field(default_factory=dict, description="API响应的内容")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
llm_parameters: dict = Field(default_factory=lambda: {
"temperature": None,
"max_tokens": None,
"top_p": None,
"frequency_penalty": None,
"presence_penalty": None,
"seed": None
}, description="API的生成参数")
class LLMRequest(SQLModel):
prompt: str = Field(..., description="发送给API的提示词")
provider_id: int = Field(foreign_key="apiprovider.id")
provider: APIProvider = Relationship()
format: Optional[str] = Field(default=None, description="API响应的格式")
response: list[LLMResponse] = Field(default_factory=list, description="API响应列表")
error: Optional[list[str]] = Field(default=None, description="API请求过程中发生的错误信息")
total_duration: float = Field(default=0.0, description="请求的总时长,单位为秒")
total_tokens_usage: dict = Field(default_factory=lambda: {
"prompt_tokens": 0,
"completion_tokens": 0,
"prompt_cache_hit_tokens": None,
"prompt_cache_miss_tokens": None
}, description="token使用信息")