Compare commits
5 Commits
8fb9f785b9
...
4b465ec917
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4b465ec917 | ||
![]() |
e7cc03297b | ||
![]() |
051d1a7535 | ||
![]() |
97172f9596 | ||
![]() |
f582820443 |
5
.gitignore
vendored
5
.gitignore
vendored
@ -30,5 +30,6 @@ workdir/
|
||||
# cache
|
||||
unsloth_compiled_cache
|
||||
|
||||
# 测试代码
|
||||
test.ipynb
|
||||
# 测试和参考代码
|
||||
test.ipynb
|
||||
refer/
|
@ -1,6 +1,7 @@
|
||||
from .chat_page import *
|
||||
from .setting_page import *
|
||||
from .train_page import *
|
||||
from .model_manage_page import *
|
||||
from .dataset_manage_page import *
|
||||
from .dataset_generate_page import *
|
||||
from .prompt_manage_page import *
|
@ -1,9 +1,28 @@
|
||||
import gradio as gr
|
||||
from global_var import model,tokenizer
|
||||
|
||||
def chat_page():
|
||||
with gr.Blocks() as demo:
|
||||
import random
|
||||
import time
|
||||
gr.Markdown("## 聊天")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
pass
|
||||
chatbot = gr.Chatbot(type="messages")
|
||||
msg = gr.Textbox()
|
||||
clear = gr.Button("Clear")
|
||||
|
||||
def user(user_message, history: list):
|
||||
return "", history + [{"role": "user", "content": user_message}]
|
||||
|
||||
def bot(history: list):
|
||||
bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
|
||||
history.append({"role": "assistant", "content": ""})
|
||||
for character in bot_message:
|
||||
history[-1]['content'] += character
|
||||
time.sleep(0.1)
|
||||
yield history
|
||||
|
||||
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
||||
bot, chatbot, chatbot
|
||||
)
|
||||
clear.click(lambda: None, None, chatbot, queue=False)
|
||||
return demo
|
@ -23,7 +23,8 @@ def dataset_manage_page():
|
||||
components=["text", "text"],
|
||||
label="问答数据",
|
||||
headers=["问题", "答案"],
|
||||
samples=[["示例问题", "示例答案"]]
|
||||
samples=[["示例问题", "示例答案"]],
|
||||
samples_per_page=20,
|
||||
)
|
||||
|
||||
def update_qa_display(dataset_name):
|
||||
|
9
frontend/model_manage_page.py
Normal file
9
frontend/model_manage_page.py
Normal file
@ -0,0 +1,9 @@
|
||||
import gradio as gr
|
||||
|
||||
def model_manage_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 模型管理")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
pass
|
||||
return demo
|
@ -4,4 +4,7 @@ from tools import scan_docs_directory
|
||||
prompt_store = get_prompt_tinydb("workdir")
|
||||
sql_engine = get_sqlite_engine("workdir")
|
||||
docs = scan_docs_directory("workdir")
|
||||
datasets = get_all_dataset("workdir")
|
||||
datasets = get_all_dataset("workdir")
|
||||
|
||||
model = None
|
||||
tokenizer = None
|
2
main.py
2
main.py
@ -10,6 +10,8 @@ if __name__ == "__main__":
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("模型管理"):
|
||||
model_manage_page()
|
||||
with gr.TabItem("模型推理"):
|
||||
chat_page()
|
||||
with gr.TabItem("模型微调"):
|
||||
|
63
tools/json_example.py
Normal file
63
tools/json_example.py
Normal file
@ -0,0 +1,63 @@
|
||||
from pydantic import BaseModel, create_model
|
||||
from typing import Any, Dict, List, Optional, get_args, get_origin
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
|
||||
def generate_example_json(model: type[BaseModel]) -> str:
|
||||
"""
|
||||
根据 Pydantic V2 模型生成示例 JSON 数据结构。
|
||||
"""
|
||||
|
||||
def _generate_example(field_type: Any) -> Any:
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
if origin is list or origin is List:
|
||||
if args:
|
||||
return [_generate_example(args[0])]
|
||||
else:
|
||||
return []
|
||||
elif origin is dict or origin is Dict:
|
||||
if len(args) == 2 and args[0] is str:
|
||||
return {"key": _generate_example(args[1])}
|
||||
else:
|
||||
return {}
|
||||
elif origin is Optional or origin is type(None):
|
||||
if args:
|
||||
return _generate_example(args[0])
|
||||
else:
|
||||
return None
|
||||
elif field_type is str:
|
||||
return "string"
|
||||
elif field_type is int:
|
||||
return 0
|
||||
elif field_type is float:
|
||||
return 0.0
|
||||
elif field_type is bool:
|
||||
return True
|
||||
elif field_type is datetime:
|
||||
return datetime.now().isoformat()
|
||||
elif field_type is date:
|
||||
return date.today().isoformat()
|
||||
elif issubclass(field_type, BaseModel):
|
||||
return generate_example_json(field_type)
|
||||
else:
|
||||
return "unknown" # 对于未知类型返回 "unknown"
|
||||
|
||||
example_data = {}
|
||||
for field_name, field in model.model_fields.items():
|
||||
example_data[field_name] = _generate_example(field.annotation)
|
||||
|
||||
return json.dumps(example_data, indent=2, default=str)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
from pathlib import Path
|
||||
# 添加项目根目录到sys.path
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||
from schema import Q_A
|
||||
class Q_A_list(BaseModel):
|
||||
Q_As: List[Q_A]
|
||||
|
||||
print("示例 JSON:")
|
||||
print(generate_example_json(Q_A_list))
|
Loading…
x
Reference in New Issue
Block a user