import gradio as gr
import sys
from pathlib import Path
from threading import Thread  # generate() runs in a background thread
from transformers import TextIteratorStreamer  # streams decoded text out of generate()

# Assumes global_var.py lives in the parent directory of this file.
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer  # presumably return the loaded model / tokenizer or None — confirm


def _build_sampling_kwargs(max_new_tokens, temperature, top_p, repetition_penalty):
    """Parse the textbox hyperparameters into ``model.generate`` keyword args.

    All four values arrive as strings from ``gr.Textbox`` components.
    A non-positive temperature selects greedy decoding. Transformers rejects a
    non-positive temperature while sampling, so instead of surfacing that as
    an error, sampling is turned off and the sampling-only knobs are omitted.

    Raises:
        ValueError: if any field cannot be parsed as a number.
    """
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)

    kwargs = dict(
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        # NOTE(review): disabling the KV cache makes every decoding step
        # recompute the full sequence (much slower). Kept as in the original;
        # confirm whether the loaded model actually requires it.
        use_cache=False,
    )
    if temperature > 0:
        kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        kwargs["do_sample"] = False
    return kwargs


def chat_page():
    """Build the chat tab: a messages-style chatbot plus generation controls.

    Returns:
        gr.Blocks: the assembled page, to be mounted in a ``TabbedInterface``
        or launched on its own.
    """
    with gr.Blocks() as demo:
        # Chat area
        gr.Markdown("## 对话")
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(type="messages", label="聊天机器人")
                msg = gr.Textbox(label="输入消息")
            with gr.Column(scale=1):
                # Generation hyperparameters; entered as text, parsed in bot().
                max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
                temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
                top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
                repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
        clear = gr.Button("清除对话")

        def user(user_message, history: list):
            """Clear the input box and append the user's turn to the history."""
            return "", history + [{"role": "user", "content": user_message}]

        def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
            """Stream the assistant's reply to the conversation into ``history``.

            Yields the (mutated) history after every non-empty streamed chunk
            so the Chatbot updates live. Failures (bad hyperparameters,
            templating, or an exception inside ``model.generate``) are shown
            as an assistant message containing the traceback instead of being
            raised.
            """
            model = get_model()
            tokenizer = get_tokenizer()
            if not history:
                yield history
                return
            if model is None or tokenizer is None:
                history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
                yield history
                return

            try:
                inputs = tokenizer.apply_chat_template(
                    history,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(model.device)
                streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
                generation_kwargs = dict(
                    input_ids=inputs,
                    streamer=streamer,
                    **_build_sampling_kwargs(max_new_tokens, temperature, top_p, repetition_penalty),
                )

                # An exception raised inside the worker thread would never
                # reach this thread, and the streamer would then wait forever
                # for its end-of-stream signal. Capture the error and close
                # the stream explicitly.
                worker_errors = []

                def _generate():
                    try:
                        model.generate(**generation_kwargs)
                    except Exception as exc:  # re-raised in the UI thread below
                        worker_errors.append(exc)
                        streamer.end()  # unblock the consuming loop

                thread = Thread(target=_generate)
                thread.start()

                history.append({"role": "assistant", "content": ""})
                for new_text in streamer:
                    if new_text:
                        history[-1]["content"] += new_text
                        yield history
                thread.join()
                if worker_errors:
                    # Route the worker's failure through the handler below.
                    raise worker_errors[0]
            except Exception:
                import traceback
                error_message = f"生成回复时出错:\n{traceback.format_exc()}"
                # Reuse the placeholder assistant turn if nothing was streamed.
                if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
                    history[-1]["content"] = error_message
                else:
                    history.append({"role": "assistant", "content": error_message})
                yield history

        # Record the user turn immediately (unqueued), then stream the reply
        # using the current hyperparameter values.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot,
            [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input],
            chatbot,
        )
        clear.click(lambda: [], None, chatbot, queue=False)
    return demo


if __name__ == "__main__":
    from model_manage_page import model_manage_page

    # Mount both pages as tabs.
    demo = gr.TabbedInterface([model_manage_page(), chat_page()], ["模型管理", "聊天"])
    demo.queue()
    demo.launch()