67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
import gradio as gr
|
|
import sys
|
|
from pathlib import Path
|
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
|
from global_var import get_model, get_tokenizer
|
|
|
|
def chat_page():
    """Build the chat tab: a chatbot pane, an input box and a clear button.

    Submitting the text box first appends the user's message to the chat
    history, then streams the model's reply into a new assistant message.

    Returns:
        The ``gr.Blocks`` container holding the chat UI, so it can be
        mounted in a ``gr.TabbedInterface`` or launched on its own.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 聊天")
        chatbot = gr.Chatbot(type="messages")
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history: list):
            """Clear the text box and append the user's message to the history."""
            return "", history + [{"role": "user", "content": user_message}]

        def bot(history: list):
            """Stream the model's reply to the latest user message.

            ``model.generate`` returns the finished sequences rather than a
            token iterator, so generation runs on a worker thread and the
            decoded text is consumed through a ``TextIteratorStreamer``.

            Yields:
                The history list, with the trailing assistant message grown
                by each decoded chunk.
            """
            # Local imports, in keeping with the lazy transformers import
            # this function already used.
            from threading import Thread

            from transformers import TextIteratorStreamer

            model = get_model()
            tokenizer = get_tokenizer()
            if model is None or tokenizer is None:
                # Presumably the getters return None until a model has been
                # loaded on the model-management page -- TODO confirm.
                raise gr.Error("No model is loaded yet.")

            # Fetch the user's latest message; only that message is sent to
            # the model (earlier turns are not included in the prompt).
            user_message = history[-1]["content"]

            # Pre-process the message with the tokenizer's chat template.
            messages = [{"role": "user", "content": user_message}]
            inputs = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(getattr(model, "device", "cuda"))

            # skip_prompt drops the echoed prompt; skip_special_tokens is
            # forwarded to tokenizer.decode for every chunk.
            streamer = TextIteratorStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            # NOTE(review): temperature/min_p only take effect when sampling
            # is enabled (do_sample=True or via the model's generation
            # config), and use_cache=False disables the KV cache -- both are
            # kept exactly as before; confirm they are intended.
            generation_kwargs = dict(
                input_ids=inputs,
                streamer=streamer,
                max_new_tokens=1024,
                use_cache=False,
                temperature=1.5,
                min_p=0.1,
            )
            errors = []

            def _generate():
                """Run generation; on failure, unblock the consumer loop."""
                try:
                    model.generate(**generation_kwargs)
                except Exception as exc:  # re-raised by the consumer below
                    errors.append(exc)
                    # Push the stop signal so iterating the streamer ends
                    # instead of waiting forever on an empty queue.
                    streamer.end()

            worker = Thread(target=_generate, daemon=True)
            worker.start()

            # One assistant message per reply, extended in place per chunk.
            history.append({"role": "assistant", "content": ""})
            for chunk in streamer:
                history[-1]["content"] += chunk
                yield history
            worker.join()

            if errors:
                raise errors[0]
            # Guarantees at least one update even if no text was generated.
            yield history

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)
    return demo
|
|
|
|
if __name__ == "__main__":
    from model_manage_page import model_manage_page

    # Mount both pages as tabs of a single app, then serve it.
    pages = [model_manage_page(), chat_page()]
    tab_titles = ["模型管理", "聊天"]
    demo = gr.TabbedInterface(pages, tab_titles)
    demo.queue()
    demo.launch()