# gzhu-biyesheji/frontend/chat_page.py
# (67 lines, 2.2 KiB, Python)

import gradio as gr
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer
def chat_page():
    """Build the chat tab: a Gradio Blocks UI that streams model replies token by token.

    Returns:
        gr.Blocks: the assembled chat page, ready to be mounted standalone
        or inside a TabbedInterface.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 聊天")
        chatbot = gr.Chatbot(type="messages")
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history: list):
            # Append the user's message to the history and clear the textbox.
            return "", history + [{"role": "user", "content": user_message}]

        def bot(history: list):
            """Generate the assistant reply, yielding partial history for live streaming."""
            model = get_model()
            tokenizer = get_tokenizer()

            # Feed the whole conversation (already in "messages" format) so the
            # model keeps context, instead of only the latest user turn.
            inputs = tokenizer.apply_chat_template(
                history,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)  # follow the model's device rather than hard-coding "cuda"

            # Iterating over model.generate()'s return value does NOT stream
            # tokens -- generate() returns the finished output tensor. The
            # supported streaming pattern is TextIteratorStreamer consumed here
            # while generate() runs in a background thread.
            from threading import Thread
            from transformers import TextIteratorStreamer

            streamer = TextIteratorStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            generation_kwargs = dict(
                input_ids=inputs,
                streamer=streamer,
                max_new_tokens=1024,
                use_cache=True,   # KV cache: same output, much faster decoding
                do_sample=True,   # required -- temperature/min_p are ignored in greedy mode
                temperature=1.5,
                min_p=0.1,
            )
            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            # Grow one assistant message in place and re-yield the history so
            # the Chatbot component updates as text arrives.
            history.append({"role": "assistant", "content": ""})
            for text_chunk in streamer:
                history[-1]["content"] += text_chunk
                yield history
            thread.join()

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)
    return demo
if __name__ == "__main__":
    from model_manage_page import model_manage_page

    # Standalone entry point: mount the model-management page and the chat
    # page side by side in a tabbed interface, then serve it.
    pages = [model_manage_page(), chat_page()]
    titles = ["模型管理", "聊天"]
    demo = gr.TabbedInterface(pages, titles)
    demo.queue()
    demo.launch()