Compare commits

...

4 Commits

Author SHA1 Message Date
carry
fb6157af05 feat(frontend): 初步实现聊天页面的智能回复功能 2025-04-11 18:08:38 +08:00
carry
f655936741 refactor(global_var): 重构全局变量初始化方法
- 新增 init_global_var 函数,用于统一初始化所有全局变量
- 修改 get_prompt_store、get_sql_engine、get_docs 和 get_datasets 函数,使用新的全局变量初始化逻辑
- 更新 main.py 中的代码,使用新的 init_global_var 函数替代原有的单独初始化方法
2025-04-11 18:08:16 +08:00
carry
ab7897351a fix(global_var): 修复全局变量多文件多副本的不统一问题 2025-04-11 18:04:42 +08:00
carry
216bfe39ae feat(tools): 添加格式化对话数据的函数
- 新增 formatting_prompts_func 函数,用于格式化对话数据
- 该函数将问题和答案组合成对话形式,并使用 tokenizer.apply_chat_template 进行格式化
- 更新 imports,添加了 unsloth.chat_templates 模块
2025-04-11 17:56:46 +08:00
9 changed files with 149 additions and 42 deletions

View File

@@ -2,7 +2,7 @@ import gradio as gr
import sys import sys
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import model,tokenizer from global_var import get_model, get_tokenizer
def chat_page(): def chat_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -17,11 +17,39 @@ def chat_page():
return "", history + [{"role": "user", "content": user_message}] return "", history + [{"role": "user", "content": user_message}]
def bot(history: list): def bot(history: list):
bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"]) model = get_model()
history.append({"role": "assistant", "content": ""}) tokenizer = get_tokenizer()
for character in bot_message: print(tokenizer)
history[-1]['content'] += character print(model)
time.sleep(0.1)
# 获取用户的最新消息
user_message = history[-1]["content"]
# 使用 tokenizer 对消息进行预处理
messages = [{"role": "user", "content": user_message}]
inputs = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to("cuda")
# 使用 TextStreamer 进行流式生成
from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt=True)
# 调用模型进行推理
generated_text = ""
for new_token in model.generate(
input_ids=inputs,
streamer=text_streamer,
max_new_tokens=1024,
use_cache=False,
temperature=1.5,
min_p=0.1,
):
generated_text += tokenizer.decode(new_token, skip_special_tokens=True)
history.append({"role": "assistant", "content": generated_text})
yield history yield history
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
@@ -32,4 +60,8 @@ def chat_page():
return demo return demo
if __name__ == "__main__": if __name__ == "__main__":
chat_page().queue().launch() from model_manage_page import model_manage_page
# 装载两个页面
demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()

View File

@@ -1,5 +1,6 @@
import gradio as gr import gradio as gr
from global_var import docs, scan_docs_directory, prompt_store from tools import scan_docs_directory
from global_var import get_docs, scan_docs_directory, get_prompt_store
def dataset_generate_page(): def dataset_generate_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -7,7 +8,7 @@ def dataset_generate_page():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
# 获取文档列表并设置初始值 # 获取文档列表并设置初始值
docs_list = [str(doc.name) for doc in scan_docs_directory("workdir")] docs_list = [str(doc.name) for doc in get_docs()]
initial_doc = docs_list[0] if docs_list else None initial_doc = docs_list[0] if docs_list else None
doc_dropdown = gr.Dropdown( doc_dropdown = gr.Dropdown(
@@ -21,7 +22,7 @@ def dataset_generate_page():
with gr.Column(): with gr.Column():
# 获取模板列表并设置初始值 # 获取模板列表并设置初始值
prompts = prompt_store.all() prompts = get_prompt_store().all()
prompt_choices = [f"{p['id']} {p['name']}" for p in prompts] prompt_choices = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_choices[0] if prompt_choices else None initial_prompt = prompt_choices[0] if prompt_choices else None

View File

@@ -1,5 +1,5 @@
import gradio as gr import gradio as gr
from global_var import datasets from global_var import get_datasets
from tinydb import Query from tinydb import Query
def dataset_manage_page(): def dataset_manage_page():
@@ -7,7 +7,7 @@ def dataset_manage_page():
gr.Markdown("## 数据集管理") gr.Markdown("## 数据集管理")
with gr.Row(): with gr.Row():
# 获取数据集列表并设置初始值 # 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in datasets.all()] datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
initial_dataset = datasets_list[0] if datasets_list else None initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown( dataset_dropdown = gr.Dropdown(
@@ -33,7 +33,7 @@ def dataset_manage_page():
# 从数据库获取数据集 # 从数据库获取数据集
Dataset = Query() Dataset = Query()
ds = datasets.get(Dataset.name == dataset_name) ds = get_datasets().get(Dataset.name == dataset_name)
if not ds: if not ds:
return {"samples": [], "__type__": "update"} return {"samples": [], "__type__": "update"}

View File

@@ -6,7 +6,7 @@ from unsloth import FastLanguageModel
import torch import torch
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import model,tokenizer from global_var import get_model, get_tokenizer, set_model, set_tokenizer
from tools.model import get_model_name from tools.model import get_model_name
def model_manage_page(): def model_manage_page():
@@ -35,9 +35,8 @@ def model_manage_page():
def load_model(selected_model, max_seq_length, load_in_4bit): def load_model(selected_model, max_seq_length, load_in_4bit):
try: try:
global model, tokenizer
# 判空操作,如果模型已加载,则先卸载 # 判空操作,如果模型已加载,则先卸载
if model is not None: if get_model() is not None:
unload_model() unload_model()
model_path = os.path.join(models_dir, selected_model) model_path = os.path.join(models_dir, selected_model)
@@ -46,6 +45,8 @@ def model_manage_page():
max_seq_length=max_seq_length, max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit, load_in_4bit=load_in_4bit,
) )
set_model(model)
set_tokenizer(tokenizer)
return f"模型 {get_model_name(model)} 已加载" return f"模型 {get_model_name(model)} 已加载"
except Exception as e: except Exception as e:
return f"加载模型时出错: {str(e)}" return f"加载模型时出错: {str(e)}"
@@ -54,20 +55,17 @@ def model_manage_page():
def unload_model(): def unload_model():
try: try:
global model, tokenizer
# 将模型移动到 CPU # 将模型移动到 CPU
model = get_model()
if model is not None: if model is not None:
model.cpu() model.cpu()
# 如果提供了 tokenizer也将其设置为 None
if tokenizer is not None:
tokenizer = None
# 清空 CUDA 缓存 # 清空 CUDA 缓存
torch.cuda.empty_cache() torch.cuda.empty_cache()
# 将模型设置为 None # 将模型和tokenizer设置为 None
model = None set_model(None)
set_tokenizer(None)
return "当前未加载模型" return "当前未加载模型"
except Exception as e: except Exception as e:

View File

@@ -3,13 +3,13 @@ import sys
from pathlib import Path from pathlib import Path
from typing import List from typing import List
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import prompt_store from global_var import get_prompt_store
from schema.prompt import promptTempleta from schema.prompt import promptTempleta
def prompt_manage_page(): def prompt_manage_page():
def get_prompts() -> List[List[str]]: def get_prompts() -> List[List[str]]:
selected_row = None selected_row = None
try: try:
db = prompt_store db = get_prompt_store()
prompts = db.all() prompts = db.all()
return [ return [
[p["id"], p["name"], p["description"], p["content"]] [p["id"], p["name"], p["description"], p["content"]]
@@ -20,7 +20,7 @@ def prompt_manage_page():
def add_prompt(name, description, content): def add_prompt(name, description, content):
try: try:
db = prompt_store db = get_prompt_store()
new_prompt = promptTempleta( new_prompt = promptTempleta(
name=name if name else "", name=name if name else "",
description=description if description else "", description=description if description else "",
@@ -38,7 +38,7 @@ def prompt_manage_page():
if not selected_row: if not selected_row:
raise gr.Error("请先选择要编辑的行") raise gr.Error("请先选择要编辑的行")
try: try:
db = prompt_store db = get_prompt_store()
db.update({ db.update({
"name": selected_row[1] if selected_row[1] else "", "name": selected_row[1] if selected_row[1] else "",
"description": selected_row[2] if selected_row[2] else "", "description": selected_row[2] if selected_row[2] else "",

View File

@@ -2,13 +2,13 @@ import gradio as gr
from typing import List from typing import List
from sqlmodel import Session, select from sqlmodel import Session, select
from schema import APIProvider from schema import APIProvider
from global_var import sql_engine from global_var import get_sql_engine
def setting_page(): def setting_page():
def get_providers() -> List[List[str]]: def get_providers() -> List[List[str]]:
selected_row = None selected_row = None
try: # 添加异常处理 try: # 添加异常处理
with Session(sql_engine) as session: with Session(get_sql_engine()) as session:
providers = session.exec(select(APIProvider)).all() providers = session.exec(select(APIProvider)).all()
return [ return [
[p.id, p.model_id, p.base_url, p.api_key or ""] [p.id, p.model_id, p.base_url, p.api_key or ""]
@@ -19,7 +19,7 @@ def setting_page():
def add_provider(model_id, base_url, api_key): def add_provider(model_id, base_url, api_key):
try: try:
with Session(sql_engine) as session: with Session(get_sql_engine()) as session:
new_provider = APIProvider( new_provider = APIProvider(
model_id=model_id if model_id else None, model_id=model_id if model_id else None,
base_url=base_url if base_url else None, base_url=base_url if base_url else None,
@@ -37,7 +37,7 @@ def setting_page():
if not selected_row: if not selected_row:
raise gr.Error("请先选择要编辑的行") raise gr.Error("请先选择要编辑的行")
try: try:
with Session(sql_engine) as session: with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0]) provider = session.get(APIProvider, selected_row[0])
if not provider: if not provider:
raise gr.Error("找不到选中的记录") raise gr.Error("找不到选中的记录")
@@ -56,7 +56,7 @@ def setting_page():
if not selected_row: if not selected_row:
raise gr.Error("请先选择要删除的行") raise gr.Error("请先选择要删除的行")
try: try:
with Session(sql_engine) as session: with Session(get_sql_engine()) as session:
provider = session.get(APIProvider, selected_row[0]) provider = session.get(APIProvider, selected_row[0])
if not provider: if not provider:
raise gr.Error("找不到选中的记录") raise gr.Error("找不到选中的记录")

View File

@@ -1,10 +1,60 @@
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory from tools import scan_docs_directory
prompt_store = get_prompt_tinydb("workdir") _prompt_store = None
sql_engine = get_sqlite_engine("workdir") _sql_engine = None
docs = scan_docs_directory("workdir") _docs = None
datasets = get_all_dataset("workdir") _datasets = None
model = None def init_global_var(workdir="workdir"):
tokenizer = None """Initialize all global variables"""
global _prompt_store, _sql_engine, _docs, _datasets
_prompt_store = get_prompt_tinydb(workdir)
_sql_engine = get_sqlite_engine(workdir)
_docs = scan_docs_directory(workdir)
_datasets = get_all_dataset(workdir)
def get_prompt_store():
return _prompt_store
def set_prompt_store(new_prompt_store):
global _prompt_store
_prompt_store = new_prompt_store
def get_sql_engine():
return _sql_engine
def set_sql_engine(new_sql_engine):
global _sql_engine
_sql_engine = new_sql_engine
def get_docs():
return _docs
def set_docs(new_docs):
global _docs
_docs = new_docs
def get_datasets():
return _datasets
def set_datasets(new_datasets):
global _datasets
_datasets = new_datasets
_model = None
_tokenizer = None
def get_model():
return _model
def set_model(new_model):
global _model
_model = new_model
def get_tokenizer():
return _tokenizer
def set_tokenizer(new_tokenizer):
global _tokenizer
_tokenizer = new_tokenizer

View File

@@ -2,11 +2,12 @@ import gradio as gr
from frontend.setting_page import setting_page from frontend.setting_page import setting_page
from frontend import * from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store from db import initialize_sqlite_db, initialize_prompt_store
from global_var import sql_engine,prompt_store from global_var import init_global_var, get_sql_engine, get_prompt_store
if __name__ == "__main__": if __name__ == "__main__":
initialize_sqlite_db(sql_engine) init_global_var()
initialize_prompt_store(prompt_store) initialize_sqlite_db(get_sql_engine())
initialize_prompt_store(get_prompt_store())
with gr.Blocks() as app: with gr.Blocks() as app:
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架") gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
with gr.Tabs(): with gr.Tabs():

View File

@@ -1,4 +1,29 @@
import os import os
from unsloth.chat_templates import get_chat_template
def formatting_prompts_func(examples,tokenizer):
"""格式化对话数据的函数
Args:
examples: 包含对话列表的字典
Returns:
包含格式化文本的字典
"""
questions = examples["question"]
answer = examples["answer"]
# 将Question和Response组合成对话形式
convos = [
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
for q, r in zip(questions, answer)
]
# 使用tokenizer.apply_chat_template格式化对话
texts = [
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
for convo in convos
]
return {"text": texts}
def get_model_name(model): def get_model_name(model):
return os.path.basename(model.name_or_path) return os.path.basename(model.name_or_path)