Compare commits
4 Commits
0fa2b51a79
...
fb6157af05
Author | SHA1 | Date | |
---|---|---|---|
![]() |
fb6157af05 | ||
![]() |
f655936741 | ||
![]() |
ab7897351a | ||
![]() |
216bfe39ae |
@ -2,7 +2,7 @@ import gradio as gr
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from global_var import model,tokenizer
|
from global_var import get_model, get_tokenizer
|
||||||
|
|
||||||
def chat_page():
|
def chat_page():
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
@ -17,11 +17,39 @@ def chat_page():
|
|||||||
return "", history + [{"role": "user", "content": user_message}]
|
return "", history + [{"role": "user", "content": user_message}]
|
||||||
|
|
||||||
def bot(history: list):
|
def bot(history: list):
|
||||||
bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
|
model = get_model()
|
||||||
history.append({"role": "assistant", "content": ""})
|
tokenizer = get_tokenizer()
|
||||||
for character in bot_message:
|
print(tokenizer)
|
||||||
history[-1]['content'] += character
|
print(model)
|
||||||
time.sleep(0.1)
|
|
||||||
|
# 获取用户的最新消息
|
||||||
|
user_message = history[-1]["content"]
|
||||||
|
|
||||||
|
# 使用 tokenizer 对消息进行预处理
|
||||||
|
messages = [{"role": "user", "content": user_message}]
|
||||||
|
inputs = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=True,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to("cuda")
|
||||||
|
|
||||||
|
# 使用 TextStreamer 进行流式生成
|
||||||
|
from transformers import TextStreamer
|
||||||
|
text_streamer = TextStreamer(tokenizer, skip_prompt=True)
|
||||||
|
|
||||||
|
# 调用模型进行推理
|
||||||
|
generated_text = ""
|
||||||
|
for new_token in model.generate(
|
||||||
|
input_ids=inputs,
|
||||||
|
streamer=text_streamer,
|
||||||
|
max_new_tokens=1024,
|
||||||
|
use_cache=False,
|
||||||
|
temperature=1.5,
|
||||||
|
min_p=0.1,
|
||||||
|
):
|
||||||
|
generated_text += tokenizer.decode(new_token, skip_special_tokens=True)
|
||||||
|
history.append({"role": "assistant", "content": generated_text})
|
||||||
yield history
|
yield history
|
||||||
|
|
||||||
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
||||||
@ -32,4 +60,8 @@ def chat_page():
|
|||||||
return demo
|
return demo
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
chat_page().queue().launch()
|
from model_manage_page import model_manage_page
|
||||||
|
# 装载两个页面
|
||||||
|
demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
|
||||||
|
demo.queue()
|
||||||
|
demo.launch()
|
@ -1,5 +1,6 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from global_var import docs, scan_docs_directory, prompt_store
|
from tools import scan_docs_directory
|
||||||
|
from global_var import get_docs, scan_docs_directory, get_prompt_store
|
||||||
|
|
||||||
def dataset_generate_page():
|
def dataset_generate_page():
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
@ -7,7 +8,7 @@ def dataset_generate_page():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
# 获取文档列表并设置初始值
|
# 获取文档列表并设置初始值
|
||||||
docs_list = [str(doc.name) for doc in scan_docs_directory("workdir")]
|
docs_list = [str(doc.name) for doc in get_docs()]
|
||||||
initial_doc = docs_list[0] if docs_list else None
|
initial_doc = docs_list[0] if docs_list else None
|
||||||
|
|
||||||
doc_dropdown = gr.Dropdown(
|
doc_dropdown = gr.Dropdown(
|
||||||
@ -21,7 +22,7 @@ def dataset_generate_page():
|
|||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
# 获取模板列表并设置初始值
|
# 获取模板列表并设置初始值
|
||||||
prompts = prompt_store.all()
|
prompts = get_prompt_store().all()
|
||||||
prompt_choices = [f"{p['id']} {p['name']}" for p in prompts]
|
prompt_choices = [f"{p['id']} {p['name']}" for p in prompts]
|
||||||
initial_prompt = prompt_choices[0] if prompt_choices else None
|
initial_prompt = prompt_choices[0] if prompt_choices else None
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from global_var import datasets
|
from global_var import get_datasets
|
||||||
from tinydb import Query
|
from tinydb import Query
|
||||||
|
|
||||||
def dataset_manage_page():
|
def dataset_manage_page():
|
||||||
@ -7,7 +7,7 @@ def dataset_manage_page():
|
|||||||
gr.Markdown("## 数据集管理")
|
gr.Markdown("## 数据集管理")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
# 获取数据集列表并设置初始值
|
# 获取数据集列表并设置初始值
|
||||||
datasets_list = [str(ds["name"]) for ds in datasets.all()]
|
datasets_list = [str(ds["name"]) for ds in get_datasets().all()]
|
||||||
initial_dataset = datasets_list[0] if datasets_list else None
|
initial_dataset = datasets_list[0] if datasets_list else None
|
||||||
|
|
||||||
dataset_dropdown = gr.Dropdown(
|
dataset_dropdown = gr.Dropdown(
|
||||||
@ -33,7 +33,7 @@ def dataset_manage_page():
|
|||||||
|
|
||||||
# 从数据库获取数据集
|
# 从数据库获取数据集
|
||||||
Dataset = Query()
|
Dataset = Query()
|
||||||
ds = datasets.get(Dataset.name == dataset_name)
|
ds = get_datasets().get(Dataset.name == dataset_name)
|
||||||
if not ds:
|
if not ds:
|
||||||
return {"samples": [], "__type__": "update"}
|
return {"samples": [], "__type__": "update"}
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from unsloth import FastLanguageModel
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from global_var import model,tokenizer
|
from global_var import get_model, get_tokenizer, set_model, set_tokenizer
|
||||||
from tools.model import get_model_name
|
from tools.model import get_model_name
|
||||||
|
|
||||||
def model_manage_page():
|
def model_manage_page():
|
||||||
@ -35,9 +35,8 @@ def model_manage_page():
|
|||||||
|
|
||||||
def load_model(selected_model, max_seq_length, load_in_4bit):
|
def load_model(selected_model, max_seq_length, load_in_4bit):
|
||||||
try:
|
try:
|
||||||
global model, tokenizer
|
|
||||||
# 判空操作,如果模型已加载,则先卸载
|
# 判空操作,如果模型已加载,则先卸载
|
||||||
if model is not None:
|
if get_model() is not None:
|
||||||
unload_model()
|
unload_model()
|
||||||
|
|
||||||
model_path = os.path.join(models_dir, selected_model)
|
model_path = os.path.join(models_dir, selected_model)
|
||||||
@ -46,6 +45,8 @@ def model_manage_page():
|
|||||||
max_seq_length=max_seq_length,
|
max_seq_length=max_seq_length,
|
||||||
load_in_4bit=load_in_4bit,
|
load_in_4bit=load_in_4bit,
|
||||||
)
|
)
|
||||||
|
set_model(model)
|
||||||
|
set_tokenizer(tokenizer)
|
||||||
return f"模型 {get_model_name(model)} 已加载"
|
return f"模型 {get_model_name(model)} 已加载"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"加载模型时出错: {str(e)}"
|
return f"加载模型时出错: {str(e)}"
|
||||||
@ -54,20 +55,17 @@ def model_manage_page():
|
|||||||
|
|
||||||
def unload_model():
|
def unload_model():
|
||||||
try:
|
try:
|
||||||
global model, tokenizer
|
|
||||||
# 将模型移动到 CPU
|
# 将模型移动到 CPU
|
||||||
|
model = get_model()
|
||||||
if model is not None:
|
if model is not None:
|
||||||
model.cpu()
|
model.cpu()
|
||||||
|
|
||||||
# 如果提供了 tokenizer,也将其设置为 None
|
|
||||||
if tokenizer is not None:
|
|
||||||
tokenizer = None
|
|
||||||
|
|
||||||
# 清空 CUDA 缓存
|
# 清空 CUDA 缓存
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# 将模型设置为 None
|
# 将模型和tokenizer设置为 None
|
||||||
model = None
|
set_model(None)
|
||||||
|
set_tokenizer(None)
|
||||||
|
|
||||||
return "当前未加载模型"
|
return "当前未加载模型"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -3,13 +3,13 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from global_var import prompt_store
|
from global_var import get_prompt_store
|
||||||
from schema.prompt import promptTempleta
|
from schema.prompt import promptTempleta
|
||||||
def prompt_manage_page():
|
def prompt_manage_page():
|
||||||
def get_prompts() -> List[List[str]]:
|
def get_prompts() -> List[List[str]]:
|
||||||
selected_row = None
|
selected_row = None
|
||||||
try:
|
try:
|
||||||
db = prompt_store
|
db = get_prompt_store()
|
||||||
prompts = db.all()
|
prompts = db.all()
|
||||||
return [
|
return [
|
||||||
[p["id"], p["name"], p["description"], p["content"]]
|
[p["id"], p["name"], p["description"], p["content"]]
|
||||||
@ -20,7 +20,7 @@ def prompt_manage_page():
|
|||||||
|
|
||||||
def add_prompt(name, description, content):
|
def add_prompt(name, description, content):
|
||||||
try:
|
try:
|
||||||
db = prompt_store
|
db = get_prompt_store()
|
||||||
new_prompt = promptTempleta(
|
new_prompt = promptTempleta(
|
||||||
name=name if name else "",
|
name=name if name else "",
|
||||||
description=description if description else "",
|
description=description if description else "",
|
||||||
@ -38,7 +38,7 @@ def prompt_manage_page():
|
|||||||
if not selected_row:
|
if not selected_row:
|
||||||
raise gr.Error("请先选择要编辑的行")
|
raise gr.Error("请先选择要编辑的行")
|
||||||
try:
|
try:
|
||||||
db = prompt_store
|
db = get_prompt_store()
|
||||||
db.update({
|
db.update({
|
||||||
"name": selected_row[1] if selected_row[1] else "",
|
"name": selected_row[1] if selected_row[1] else "",
|
||||||
"description": selected_row[2] if selected_row[2] else "",
|
"description": selected_row[2] if selected_row[2] else "",
|
||||||
|
@ -2,13 +2,13 @@ import gradio as gr
|
|||||||
from typing import List
|
from typing import List
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from schema import APIProvider
|
from schema import APIProvider
|
||||||
from global_var import sql_engine
|
from global_var import get_sql_engine
|
||||||
|
|
||||||
def setting_page():
|
def setting_page():
|
||||||
def get_providers() -> List[List[str]]:
|
def get_providers() -> List[List[str]]:
|
||||||
selected_row = None
|
selected_row = None
|
||||||
try: # 添加异常处理
|
try: # 添加异常处理
|
||||||
with Session(sql_engine) as session:
|
with Session(get_sql_engine()) as session:
|
||||||
providers = session.exec(select(APIProvider)).all()
|
providers = session.exec(select(APIProvider)).all()
|
||||||
return [
|
return [
|
||||||
[p.id, p.model_id, p.base_url, p.api_key or ""]
|
[p.id, p.model_id, p.base_url, p.api_key or ""]
|
||||||
@ -19,7 +19,7 @@ def setting_page():
|
|||||||
|
|
||||||
def add_provider(model_id, base_url, api_key):
|
def add_provider(model_id, base_url, api_key):
|
||||||
try:
|
try:
|
||||||
with Session(sql_engine) as session:
|
with Session(get_sql_engine()) as session:
|
||||||
new_provider = APIProvider(
|
new_provider = APIProvider(
|
||||||
model_id=model_id if model_id else None,
|
model_id=model_id if model_id else None,
|
||||||
base_url=base_url if base_url else None,
|
base_url=base_url if base_url else None,
|
||||||
@ -37,7 +37,7 @@ def setting_page():
|
|||||||
if not selected_row:
|
if not selected_row:
|
||||||
raise gr.Error("请先选择要编辑的行")
|
raise gr.Error("请先选择要编辑的行")
|
||||||
try:
|
try:
|
||||||
with Session(sql_engine) as session:
|
with Session(get_sql_engine()) as session:
|
||||||
provider = session.get(APIProvider, selected_row[0])
|
provider = session.get(APIProvider, selected_row[0])
|
||||||
if not provider:
|
if not provider:
|
||||||
raise gr.Error("找不到选中的记录")
|
raise gr.Error("找不到选中的记录")
|
||||||
@ -56,7 +56,7 @@ def setting_page():
|
|||||||
if not selected_row:
|
if not selected_row:
|
||||||
raise gr.Error("请先选择要删除的行")
|
raise gr.Error("请先选择要删除的行")
|
||||||
try:
|
try:
|
||||||
with Session(sql_engine) as session:
|
with Session(get_sql_engine()) as session:
|
||||||
provider = session.get(APIProvider, selected_row[0])
|
provider = session.get(APIProvider, selected_row[0])
|
||||||
if not provider:
|
if not provider:
|
||||||
raise gr.Error("找不到选中的记录")
|
raise gr.Error("找不到选中的记录")
|
||||||
|
@ -1,10 +1,60 @@
|
|||||||
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
from db import get_sqlite_engine, get_prompt_tinydb, get_all_dataset
|
||||||
from tools import scan_docs_directory
|
from tools import scan_docs_directory
|
||||||
|
|
||||||
prompt_store = get_prompt_tinydb("workdir")
|
_prompt_store = None
|
||||||
sql_engine = get_sqlite_engine("workdir")
|
_sql_engine = None
|
||||||
docs = scan_docs_directory("workdir")
|
_docs = None
|
||||||
datasets = get_all_dataset("workdir")
|
_datasets = None
|
||||||
|
|
||||||
model = None
|
def init_global_var(workdir="workdir"):
|
||||||
tokenizer = None
|
"""Initialize all global variables"""
|
||||||
|
global _prompt_store, _sql_engine, _docs, _datasets
|
||||||
|
_prompt_store = get_prompt_tinydb(workdir)
|
||||||
|
_sql_engine = get_sqlite_engine(workdir)
|
||||||
|
_docs = scan_docs_directory(workdir)
|
||||||
|
_datasets = get_all_dataset(workdir)
|
||||||
|
|
||||||
|
def get_prompt_store():
|
||||||
|
return _prompt_store
|
||||||
|
|
||||||
|
def set_prompt_store(new_prompt_store):
|
||||||
|
global _prompt_store
|
||||||
|
_prompt_store = new_prompt_store
|
||||||
|
|
||||||
|
def get_sql_engine():
|
||||||
|
return _sql_engine
|
||||||
|
|
||||||
|
def set_sql_engine(new_sql_engine):
|
||||||
|
global _sql_engine
|
||||||
|
_sql_engine = new_sql_engine
|
||||||
|
|
||||||
|
def get_docs():
|
||||||
|
return _docs
|
||||||
|
|
||||||
|
def set_docs(new_docs):
|
||||||
|
global _docs
|
||||||
|
_docs = new_docs
|
||||||
|
|
||||||
|
def get_datasets():
|
||||||
|
return _datasets
|
||||||
|
|
||||||
|
def set_datasets(new_datasets):
|
||||||
|
global _datasets
|
||||||
|
_datasets = new_datasets
|
||||||
|
|
||||||
|
_model = None
|
||||||
|
_tokenizer = None
|
||||||
|
|
||||||
|
def get_model():
|
||||||
|
return _model
|
||||||
|
|
||||||
|
def set_model(new_model):
|
||||||
|
global _model
|
||||||
|
_model = new_model
|
||||||
|
|
||||||
|
def get_tokenizer():
|
||||||
|
return _tokenizer
|
||||||
|
|
||||||
|
def set_tokenizer(new_tokenizer):
|
||||||
|
global _tokenizer
|
||||||
|
_tokenizer = new_tokenizer
|
9
main.py
9
main.py
@ -1,12 +1,13 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from frontend.setting_page import setting_page
|
from frontend.setting_page import setting_page
|
||||||
from frontend import *
|
from frontend import *
|
||||||
from db import initialize_sqlite_db,initialize_prompt_store
|
from db import initialize_sqlite_db, initialize_prompt_store
|
||||||
from global_var import sql_engine,prompt_store
|
from global_var import init_global_var, get_sql_engine, get_prompt_store
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
initialize_sqlite_db(sql_engine)
|
init_global_var()
|
||||||
initialize_prompt_store(prompt_store)
|
initialize_sqlite_db(get_sql_engine())
|
||||||
|
initialize_prompt_store(get_prompt_store())
|
||||||
with gr.Blocks() as app:
|
with gr.Blocks() as app:
|
||||||
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
|
gr.Markdown("# 基于文档驱动的自适应编码大模型微调框架")
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
|
@ -1,4 +1,29 @@
|
|||||||
import os
|
import os
|
||||||
|
from unsloth.chat_templates import get_chat_template
|
||||||
|
|
||||||
|
def formatting_prompts_func(examples,tokenizer):
|
||||||
|
"""格式化对话数据的函数
|
||||||
|
Args:
|
||||||
|
examples: 包含对话列表的字典
|
||||||
|
Returns:
|
||||||
|
包含格式化文本的字典
|
||||||
|
"""
|
||||||
|
questions = examples["question"]
|
||||||
|
answer = examples["answer"]
|
||||||
|
|
||||||
|
# 将Question和Response组合成对话形式
|
||||||
|
convos = [
|
||||||
|
[{"role": "user", "content": q}, {"role": "assistant", "content": r}]
|
||||||
|
for q, r in zip(questions, answer)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 使用tokenizer.apply_chat_template格式化对话
|
||||||
|
texts = [
|
||||||
|
tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
|
||||||
|
for convo in convos
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"text": texts}
|
||||||
|
|
||||||
def get_model_name(model):
|
def get_model_name(model):
|
||||||
return os.path.basename(model.name_or_path)
|
return os.path.basename(model.name_or_path)
|
Loading…
x
Reference in New Issue
Block a user