gzhu-biyesheji/frontend/chat_page.py
carry 61672021ef fix(frontend): 修复聊天页面并的流式回复
- 导入 Thread 和 TextIteratorStreamer 以支持流式生成
- 重新设计 user 和 bot 函数,优化对话历史处理
- 添加异常处理和错误信息显示
- 改进模型和分词器的加载逻辑
- 优化聊天页面布局和交互
2025-04-11 18:33:31 +08:00

107 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gradio as gr
import sys
from pathlib import Path
from threading import Thread # 需要导入 Thread
from transformers import TextIteratorStreamer # 使用 TextIteratorStreamer
# 假设 global_var.py 在父目录
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer # 假设这两个函数能正确获取模型和分词器
def chat_page():
with gr.Blocks() as demo:
gr.Markdown("## 聊天")
chatbot = gr.Chatbot(type="messages", label="聊天机器人") # 使用 messages 类型label 可选
msg = gr.Textbox(label="输入消息")
clear = gr.Button("清除对话")
def user(user_message, history: list):
# 清空输入框,并将用户消息添加到 history
# 确保 history 是 list of lists 或者 list of tuples (根据 Gradio 版本和 Chatbot type)
# 对于 type="messages",期望的格式是 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
return "", history + [{"role": "user", "content": user_message}]
def bot(history: list):
model = get_model()
tokenizer = get_tokenizer()
if not history: # 避免 history 为空时出错
yield history
return
# 检查模型和分词器是否存在
if model is None or tokenizer is None:
history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
yield history
return
try:
# --- 关键改动点 ---
# 1. 使用完整的 history (或者根据需要裁剪) 来创建输入
# apply_chat_template 通常能处理这种 [{"role": ..., "content": ...}] 格式
# 确保你的模型和 tokenizer 支持 chat template
inputs = tokenizer.apply_chat_template(
history,
tokenize=True,
add_generation_prompt=True, # 对很多指令调优模型是必要的
return_tensors="pt",
).to(model.device) # 将输入张量移动到模型所在的设备
# 2. 使用 TextIteratorStreamer
# skip_prompt=True: 不在流中包含原始输入 prompt
# skip_special_tokens=True: 不在流中包含特殊 token (如 <eos>)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 3. 将 model.generate 放入单独的线程中运行
# 这样 Gradio 的主线程不会被阻塞,可以接收流式输出
generation_kwargs = dict(
input_ids=inputs,
streamer=streamer,
max_new_tokens=1024,
# temperature=1.5, # 1.5 通常太高,容易产生随机无意义内容,建议 0.7-1.0
# min_p=0.1, # min_p 不常用top_p 更常见
do_sample=True, # 启用采样,让 temperature/top_p 生效
temperature=0.8, # 稍微降低温度
top_p=0.95, # 使用 top_p 采样
repetition_penalty=1.1, # 轻微惩罚重复
use_cache=False # 通常可以开启以加速生成,除非有特殊原因
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# 4. 在 history 中添加一个空的 assistant 回复占位符
history.append({"role": "assistant", "content": ""})
# 5. 迭代 streamer实时更新 history 中最后一条消息的内容
for new_text in streamer:
if new_text: # 确保不是空字符串
history[-1]["content"] += new_text
yield history # yield 更新后的 history 给 Gradio Chatbot
except Exception as e:
# 异常处理,将错误信息显示在聊天框
import traceback
error_message = f"生成回复时出错:\n{traceback.format_exc()}"
# 如果最后一条消息是助手的空消息,覆盖它;否则追加
if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
history[-1]["content"] = error_message
else:
history.append({"role": "assistant", "content": error_message})
yield history # 显示错误信息
# .then() 中 bot 函数的输出直接连接到 chatbot 组件
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, chatbot, chatbot
)
# 清除按钮点击后,返回一个空列表来清空 chatbot
clear.click(lambda: [], None, chatbot, queue=False)
return demo
if __name__ == "__main__":
from model_manage_page import model_manage_page
# 装载两个页面
demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
demo.queue()
demo.launch()