Compare commits

..

No commits in common. "fb6157af058092369dbb559f4d4342caccbc3382" and "0fa2b51a799ed0c4fa58cf64d737ea4b85542dbf" have entirely different histories.

9 changed files with 42 additions and 149 deletions

View File

@@ -2,7 +2,7 @@ import gradio as gr
import sys import sys
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer from global_var import model,tokenizer
def chat_page(): def chat_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -17,39 +17,11 @@ def chat_page():
return "", history + [{"role": "user", "content": user_message}] return "", history + [{"role": "user", "content": user_message}]
def bot(history: list): def bot(history: list):
model = get_model() bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
tokenizer = get_tokenizer() history.append({"role": "assistant", "content": ""})
print(tokenizer) for character in bot_message:
print(model) history[-1]['content'] += character
time.sleep(0.1)
# 获取用户的最新消息
user_message = history[-1]["content"]
# 使用 tokenizer 对消息进行预处理
messages = [{"role": "user", "content": user_message}]
inputs = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to("cuda")
# 使用 TextStreamer 进行流式生成
from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt=True)
# 调用模型进行推理
generated_text = ""
for new_token in model.generate(
input_ids=inputs,
streamer=text_streamer,
max_new_tokens=1024,
use_cache=False,
temperature=1.5,
min_p=0.1,
):
generated_text += tokenizer.decode(new_token, skip_special_tokens=True)
history.append({"role": "assistant", "content": generated_text})
yield history yield history
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
@@ -60,8 +32,4 @@ def chat_page():
return demo return demo
if __name__ == "__main__": if __name__ == "__main__":
from model_manage_page import model_manage_page chat_page().queue().launch()
# 装载两个页面
demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()

View File

@@ -1,6 +1,5 @@
import gradio as gr import gradio as gr
from tools import scan_docs_directory from global_var import docs, scan_docs_directory, prompt_store
from global_var import get_docs, scan_docs_directory, get_prompt_store
def dataset_generate_page(): def dataset_generate_page():
with gr.Blocks() as demo: with gr.Blocks() as demo:
@@ -8,7 +7,7 @@ def dataset_generate_page():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
# 获取文档列表并设置初始值 # 获取文档列表并设置初始值
docs_list = [str(doc.name) for doc in get_docs()] docs_list = [str(doc.name) for doc in scan_docs_directory("workdir")]
initial_doc = docs_list[0] if docs_list else None initial_doc = docs_list[0] if docs_list else None
doc_dropdown = gr.Dropdown( doc_dropdown = gr.Dropdown(
@@ -22,7 +21,7 @@ def dataset_generate_page():
with gr.Column(): with gr.Column():
# 获取模板列表并设置初始值 # 获取模板列表并设置初始值
prompts = get_prompt_store().all() prompts = prompt_store.all()
prompt_choices = [f"{p['id']} {p['name']}" for p in prompts] prompt_choices = [f"{p['id']} {p['name']}" for p in prompts]
initial_prompt = prompt_choices[0] if prompt_choices else None initial_prompt = prompt_choices[0] if prompt_choices else None

View File

@@ -1,5 +1,5 @@
import gradio as gr import gradio as gr
from global_var import get_datasets from global_var import datasets
from tinydb import Query from tinydb import Query
def dataset_manage_page(): def dataset_manage_page():
@@ -7,7 +7,7 @@ def dataset_manage_page():
gr.Markdown("## 数据集管理") gr.Markdown("## 数据集管理")
with gr.Row(): with gr.Row():
# 获取数据集列表并设置初始值 # 获取数据集列表并设置初始值
datasets_list = [str(ds["name"]) for ds in get_datasets().all()] datasets_list = [str(ds["name"]) for ds in datasets.all()]
initial_dataset = datasets_list[0] if datasets_list else None initial_dataset = datasets_list[0] if datasets_list else None
dataset_dropdown = gr.Dropdown( dataset_dropdown = gr.Dropdown(
@@ -33,7 +33,7 @@ def dataset_manage_page():
# 从数据库获取数据集 # 从数据库获取数据集
Dataset = Query() Dataset = Query()
ds = get_datasets().get(Dataset.name == dataset_name) ds = datasets.get(Dataset.name == dataset_name)
if not ds: if not ds:
return {"samples": [], "__type__": "update"} return {"samples": [], "__type__": "update"}

View File

@@ -6,7 +6,7 @@ from unsloth import FastLanguageModel
import torch import torch
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer, set_model, set_tokenizer from global_var import model,tokenizer
from tools.model import get_model_name from tools.model import get_model_name
def model_manage_page(): def model_manage_page():
@@ -35,8 +35,9 @@ def model_manage_page():
def load_model(selected_model, max_seq_length, load_in_4bit): def load_model(selected_model, max_seq_length, load_in_4bit):
try: try:
global model, tokenizer
# 判空操作,如果模型已加载,则先卸载 # 判空操作,如果模型已加载,则先卸载
if get_model() is not None: if model is not None:
unload_model() unload_model()
model_path = os.path.join(models_dir, selected_model) model_path = os.path.join(models_dir, selected_model)
@@ -45,8 +46,6 @@ def model_manage_page():
max_seq_length=max_seq_length, max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit, load_in_4bit=load_in_4bit,
) )
set_model(model)
set_tokenizer(tokenizer)
return f"模型 {get_model_name(model)} 已加载" return f"模型 {get_model_name(model)} 已加载"
except Exception as e: except Exception as e:
return f"加载模型时出错: {str(e)}" return f"加载模型时出错: {str(e)}"
@@ -55,17 +54,20 @@ def model_manage_page():
def unload_model(): def unload_model():
try: try:
global model, tokenizer
# 将模型移动到 CPU # 将模型移动到 CPU
model = get_model()
if model is not None: if model is not None:
model.cpu() model.cpu()
# 如果提供了 tokenizer也将其设置为 None
if tokenizer is not None:
tokenizer = None
# 清空 CUDA 缓存 # 清空 CUDA 缓存
torch.cuda.empty_cache() torch.cuda.empty_cache()
# 将模型和tokenizer设置为 None # 将模型设置为 None
set_model(None) model = None
set_tokenizer(None)
return "当前未加载模型" return "当前未加载模型"
except Exception as e: except Exception as e:

View File

@@ -3,13 +3,13 @@ import sys
from pathlib import Path from pathlib import Path
from typing import List from typing import List
sys.path.append(str(Path(__file__).resolve().parent.parent)) sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_prompt_store from global_var import prompt_store
from schema.prompt import promptTempleta from schema.prompt import promptTempleta
def prompt_manage_page(): def prompt_manage_page():
def get_prompts() -> List[List[str]]: def get_prompts() -> List[List[str]]:
selected_row = None selected_row = None
try: try:
db = get_prompt_store() db = prompt_store
prompts = db.all() prompts = db.all()
return [ return [
[p["id"], p["name"], p["description"], p["content"]] [p["id"], p["name"], p["description"], p["content"]]
@@ -20,7 +20,7 @@ def prompt_manage_page():
def add_prompt(name, description, content): def add_prompt(name, description, content):
try: try:
db = get_prompt_store() db = prompt_store
new_prompt = promptTempleta( new_prompt = promptTempleta(
name=name if name else "", name=name if name else "",
description=description if description else "", description=description if description else "",
@@ -38,7 +38,7 @@ def prompt_manage_page():
if not selected_row: if not selected_row:
raise gr.Error("请先选择要编辑的行") raise gr.Error("请先选择要编辑的行")
try: try:
db = get_prompt_store() db = prompt_store
db.update({ db.update({
"name": selected_row[1] if selected_row[1] else "", "name": selected_row[1] if selected_row[1] else "",
"description": selected_row[2] if selected_row[2] else "", "description": selected_row[2] if selected_row[2] else "",

View File

@@ -2,13 +2,13 @@ import gradio as gr
from typing import List from typing import List
from sqlmodel import Session, select from sqlmodel import Session, select
from schema import APIProvider from schema import APIProvider
from global_var import get_sql_engine from global_var import sql_engine
def setting_page(): def setting_page():
def get_providers() -> List[List[str]]: def get_providers() -> List[List[str]]:
selected_row = None selected_row = None
try: # 添加异常处理 try: # 添加异常处理
with Session(get_sql_engine()) as session: with Session(sql_engine) as session:
providers = session.exec(select(APIProvider)).all() providers = session.exec(select(APIProvider)).all()
return [ return [
[p.id, p.model_id, p.base_url, p.api_key or ""] [p.id, p.model_id, p.base_url, p.api_key or ""]
@@ -19,7 +19,7 @@ def setting_page():
def add_provider(model_id, base_url, api_key): def add_provider(model_id, base_url, api_key):
try: try:
with Session(get_sql_engine()) as session: with Session(sql_engine) as session:
new_provider = APIProvider( new_provider = APIProvider(
model_id=model_id if model_id else None, model_id=model_id if model_id else None,
base_url=base_url if base_url else None, base_url=base_url if base_url else None,
@@ -37,7 +37,7 @@ def setting_page():
if not selected_row: if not selected_row:
raise gr.Error("请先选择要编辑的行") raise gr.Error("请先选择要编辑的行")
try: try:
with Session(get_sql_engine()) as session: with Session(sql_engine) as session:
provider = session.get(APIProvider, selected_row[0]) provider = session.get(APIProvider, selected_row[0])
if not provider: if not provider:
raise gr.Error("找不到选中的记录") raise gr.Error("找不到选中的记录")
@@ -56,7 +56,7 @@ def setting_page():
if not selected_row: if not selected_row:
raise gr.Error("请先选择要删除的行") raise gr.Error("请先选择要删除的行")
try: try:
with Session(get_sql_engine()) as session: with Session(sql_engine) as session:
provider = session.get(APIProvider, selected_row[0]) provider = session.get(APIProvider, selected_row[0])
if not provider: if not provider:
raise gr.Error("找不到选中的记录") raise gr.Error("找不到选中的记录")

View File

@@ -1,60 +1,10 @@
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
from tools import scan_docs_directory from tools import scan_docs_directory
_prompt_store = None prompt_store = get_prompt_tinydb("workdir")
_sql_engine = None sql_engine = get_sqlite_engine("workdir")
_docs = None docs = scan_docs_directory("workdir")
_datasets = None datasets = get_all_dataset("workdir")
def init_global_var(workdir="workdir"): model = None
"""Initialize all global variables""" tokenizer = None
global _prompt_store, _sql_engine, _docs, _datasets
_prompt_store = get_prompt_tinydb(workdir)
_sql_engine = get_sqlite_engine(workdir)
_docs = scan_docs_directory(workdir)
_datasets = get_all_dataset(workdir)
def get_prompt_store():
return _prompt_store
def set_prompt_store(new_prompt_store):
global _prompt_store
_prompt_store = new_prompt_store
def get_sql_engine():
return _sql_engine
def set_sql_engine(new_sql_engine):
global _sql_engine
_sql_engine = new_sql_engine
def get_docs():
return _docs
def set_docs(new_docs):
global _docs
_docs = new_docs
def get_datasets():
return _datasets
def set_datasets(new_datasets):
global _datasets
_datasets = new_datasets
_model = None
_tokenizer = None
def get_model():
return _model
def set_model(new_model):
global _model
_model = new_model
def get_tokenizer():
return _tokenizer
def set_tokenizer(new_tokenizer):
global _tokenizer
_tokenizer = new_tokenizer

View File

@@ -1,13 +1,12 @@
import gradio as gr import gradio as gr
from frontend.setting_page import setting_page from frontend.setting_page import setting_page
from frontend import * from frontend import *
from db import initialize_sqlite_db, initialize_prompt_store from db import initialize_sqlite_db,initialize_prompt_store
from global_var import init_global_var, get_sql_engine, get_prompt_store from global_var import sql_engine,prompt_store
if __name__ == "__main__": if __name__ == "__main__":
init_global_var() initialize_sqlite_db(sql_engine)
initialize_sqlite_db(get_sql_engine()) initialize_prompt_store(prompt_store)
initialize_prompt_store(get_prompt_store())
with gr.Blocks() as app: with gr.Blocks() as app:
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架") gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
with gr.Tabs(): with gr.Tabs():

View File

@@ -1,29 +1,4 @@
import os import os
from unsloth.chat_templates import get_chat_template
def formatting_prompts_func(examples,tokenizer):
"""格式化对话数据的函数
Args:
examples: 包含对话列表的字典
Returns:
包含格式化文本的字典
"""
questions = examples["question"]
answer = examples["answer"]
# 将Question和Response组合成对话形式
convos = [
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
for q, r in zip(questions, answer)
]
# 使用tokenizer.apply_chat_template格式化对话
texts = [
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
for convo in convos
]
return {"text": texts}
def get_model_name(model): def get_model_name(model):
return os.path.basename(model.name_or_path) return os.path.basename(model.name_or_path)