fix(frontend): 修复聊天页面并的流式回复

- 导入 Thread 和 TextIteratorStreamer 以支持流式生成
- 重新设计 user 和 bot 函数,优化对话历史处理
- 添加异常处理和错误信息显示
- 改进模型和分词器的加载逻辑
- 优化聊天页面布局和交互
This commit is contained in:
carry 2025-04-11 18:33:31 +08:00
parent fb6157af05
commit 61672021ef

View File

@ -1,62 +1,102 @@
import gradio as gr
import sys
from pathlib import Path
from threading import Thread # 需要导入 Thread
from transformers import TextIteratorStreamer # 使用 TextIteratorStreamer
# 假设 global_var.py 在父目录
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer
from global_var import get_model, get_tokenizer # 假设这两个函数能正确获取模型和分词器
def chat_page():
with gr.Blocks() as demo:
import random
import time
gr.Markdown("## 聊天")
chatbot = gr.Chatbot(type="messages")
msg = gr.Textbox()
clear = gr.Button("Clear")
chatbot = gr.Chatbot(type="messages", label="聊天机器人") # 使用 messages 类型label 可选
msg = gr.Textbox(label="输入消息")
clear = gr.Button("清除对话")
def user(user_message, history: list):
# 清空输入框,并将用户消息添加到 history
# 确保 history 是 list of lists 或者 list of tuples (根据 Gradio 版本和 Chatbot type)
# 对于 type="messages",期望的格式是 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
return "", history + [{"role": "user", "content": user_message}]
def bot(history: list):
model = get_model()
tokenizer = get_tokenizer()
print(tokenizer)
print(model)
# 获取用户的最新消息
user_message = history[-1]["content"]
# 使用 tokenizer 对消息进行预处理
messages = [{"role": "user", "content": user_message}]
inputs = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to("cuda")
# 使用 TextStreamer 进行流式生成
from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt=True)
# 调用模型进行推理
generated_text = ""
for new_token in model.generate(
input_ids=inputs,
streamer=text_streamer,
max_new_tokens=1024,
use_cache=False,
temperature=1.5,
min_p=0.1,
):
generated_text += tokenizer.decode(new_token, skip_special_tokens=True)
history.append({"role": "assistant", "content": generated_text})
if not history: # 避免 history 为空时出错
yield history
return
# 检查模型和分词器是否存在
if model is None or tokenizer is None:
history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
yield history
return
try:
# --- 关键改动点 ---
# 1. 使用完整的 history (或者根据需要裁剪) 来创建输入
# apply_chat_template 通常能处理这种 [{"role": ..., "content": ...}] 格式
# 确保你的模型和 tokenizer 支持 chat template
inputs = tokenizer.apply_chat_template(
history,
tokenize=True,
add_generation_prompt=True, # 对很多指令调优模型是必要的
return_tensors="pt",
).to(model.device) # 将输入张量移动到模型所在的设备
# 2. 使用 TextIteratorStreamer
# skip_prompt=True: 不在流中包含原始输入 prompt
# skip_special_tokens=True: 不在流中包含特殊 token (如 <eos>)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 3. 将 model.generate 放入单独的线程中运行
# 这样 Gradio 的主线程不会被阻塞,可以接收流式输出
generation_kwargs = dict(
input_ids=inputs,
streamer=streamer,
max_new_tokens=1024,
# temperature=1.5, # 1.5 通常太高,容易产生随机无意义内容,建议 0.7-1.0
# min_p=0.1, # min_p 不常用top_p 更常见
do_sample=True, # 启用采样,让 temperature/top_p 生效
temperature=0.8, # 稍微降低温度
top_p=0.95, # 使用 top_p 采样
repetition_penalty=1.1, # 轻微惩罚重复
use_cache=False # 通常可以开启以加速生成,除非有特殊原因
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# 4. 在 history 中添加一个空的 assistant 回复占位符
history.append({"role": "assistant", "content": ""})
# 5. 迭代 streamer实时更新 history 中最后一条消息的内容
for new_text in streamer:
if new_text: # 确保不是空字符串
history[-1]["content"] += new_text
yield history # yield 更新后的 history 给 Gradio Chatbot
except Exception as e:
# 异常处理,将错误信息显示在聊天框
import traceback
error_message = f"生成回复时出错:\n{traceback.format_exc()}"
# 如果最后一条消息是助手的空消息,覆盖它;否则追加
if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
history[-1]["content"] = error_message
else:
history.append({"role": "assistant", "content": error_message})
yield history # 显示错误信息
# .then() 中 bot 函数的输出直接连接到 chatbot 组件
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, chatbot, chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
# 清除按钮点击后,返回一个空列表来清空 chatbot
clear.click(lambda: [], None, chatbot, queue=False)
return demo
if __name__ == "__main__":