diff --git a/frontend/chat_page.py b/frontend/chat_page.py index 7ab6806..736c470 100644 --- a/frontend/chat_page.py +++ b/frontend/chat_page.py @@ -10,91 +10,81 @@ from global_var import get_model, get_tokenizer # 假设这两个函数能正确 def chat_page(): with gr.Blocks() as demo: - gr.Markdown("## 聊天") - chatbot = gr.Chatbot(type="messages", label="聊天机器人") # 使用 messages 类型,label 可选 - msg = gr.Textbox(label="输入消息") - clear = gr.Button("清除对话") + # 聊天框 + gr.Markdown("## 对话") + with gr.Row(): + with gr.Column(scale=4): + chatbot = gr.Chatbot(type="messages", label="聊天机器人") + msg = gr.Textbox(label="输入消息") + + with gr.Column(scale=1): + # 新增超参数输入框 + max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024") + temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8") + top_p_input = gr.Textbox(label="Top-p 采样", value="0.95") + repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1") + clear = gr.Button("清除对话") def user(user_message, history: list): - # 清空输入框,并将用户消息添加到 history - # 确保 history 是 list of lists 或者 list of tuples (根据 Gradio 版本和 Chatbot type) - # 对于 type="messages",期望的格式是 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] return "", history + [{"role": "user", "content": user_message}] - def bot(history: list): + def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty): model = get_model() tokenizer = get_tokenizer() - if not history: # 避免 history 为空时出错 + if not history: yield history return - # 检查模型和分词器是否存在 if model is None or tokenizer is None: - history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"}) - yield history - return + history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"}) + yield history + return try: - # --- 关键改动点 --- - # 1. 使用完整的 history (或者根据需要裁剪) 来创建输入 - # apply_chat_template 通常能处理这种 [{"role": ..., "content": ...}] 格式 - # 确保你的模型和 tokenizer 支持 chat template inputs = tokenizer.apply_chat_template( history, tokenize=True, - add_generation_prompt=True, # 对很多指令调优模型是必要的 + add_generation_prompt=True, return_tensors="pt", - ).to(model.device) # 将输入张量移动到模型所在的设备 + ).to(model.device) - # 2. 使用 TextIteratorStreamer - # skip_prompt=True: 不在流中包含原始输入 prompt - # skip_special_tokens=True: 不在流中包含特殊 token (如 ) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) - # 3. 将 model.generate 放入单独的线程中运行 - # 这样 Gradio 的主线程不会被阻塞,可以接收流式输出 + # 将超参数转换为数值类型 generation_kwargs = dict( input_ids=inputs, streamer=streamer, - max_new_tokens=1024, - # temperature=1.5, # 1.5 通常太高,容易产生随机无意义内容,建议 0.7-1.0 - # min_p=0.1, # min_p 不常用,top_p 更常见 - do_sample=True, # 启用采样,让 temperature/top_p 生效 - temperature=0.8, # 稍微降低温度 - top_p=0.95, # 使用 top_p 采样 - repetition_penalty=1.1, # 轻微惩罚重复 - use_cache=False # 通常可以开启以加速生成,除非有特殊原因 + max_new_tokens=int(max_new_tokens), + temperature=float(temperature), + top_p=float(top_p), + repetition_penalty=float(repetition_penalty), + do_sample=True, + use_cache=False ) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() - # 4. 在 history 中添加一个空的 assistant 回复占位符 history.append({"role": "assistant", "content": ""}) - # 5. 迭代 streamer,实时更新 history 中最后一条消息的内容 for new_text in streamer: - if new_text: # 确保不是空字符串 + if new_text: history[-1]["content"] += new_text - yield history # yield 更新后的 history 给 Gradio Chatbot + yield history except Exception as e: - # 异常处理,将错误信息显示在聊天框 - import traceback - error_message = f"生成回复时出错:\n{traceback.format_exc()}" - # 如果最后一条消息是助手的空消息,覆盖它;否则追加 - if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "": + import traceback + error_message = f"生成回复时出错:\n{traceback.format_exc()}" + if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "": history[-1]["content"] = error_message - else: + else: history.append({"role": "assistant", "content": error_message}) - yield history # 显示错误信息 + yield history - - # .then() 中 bot 函数的输出直接连接到 chatbot 组件 + # 更新 .then() 调用,将超参数传递给 bot 函数 msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( - bot, chatbot, chatbot + bot, [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input], chatbot ) - # 清除按钮点击后,返回一个空列表来清空 chatbot clear.click(lambda: [], None, chatbot, queue=False) return demo