Compare commits

..

No commits in common. "83427aaabaa89df528a4af04a1892f168f59ae53" and "fb6157af058092369dbb559f4d4342caccbc3382" have entirely different histories.

View File

@@ -1,92 +1,62 @@
import gradio as gr
import sys
from pathlib import Path
from threading import Thread  # run model.generate in the background while we stream
from transformers import TextIteratorStreamer  # iterator-style streamer for token-by-token output

# global_var.py lives in the parent directory; make it importable.
sys.path.append(str(Path(__file__).resolve().parent.parent))
from global_var import get_model, get_tokenizer


def chat_page():
    """Build and return the Gradio Blocks UI for the chat page.

    Layout: a messages-style chatbot plus a text input on the left, and
    generation hyperparameter inputs (max tokens, temperature, top-p,
    repetition penalty) with a clear button on the right. Bot replies are
    streamed token-by-token via TextIteratorStreamer.

    Returns:
        gr.Blocks: the assembled demo, ready for .launch() by the caller.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## 对话")
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(type="messages", label="聊天机器人")
                msg = gr.Textbox(label="输入消息")
            with gr.Column(scale=1):
                # Generation hyperparameters; kept as Textboxes (strings) and
                # converted to numbers inside bot().
                max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
                temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
                top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
                repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
                clear = gr.Button("清除对话")

        def user(user_message, history: list):
            """Append the user's message to the history and clear the textbox."""
            return "", history + [{"role": "user", "content": user_message}]

        def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
            """Stream the assistant's reply into the last history entry.

            Args:
                history: chat history as a list of {"role", "content"} dicts.
                max_new_tokens / temperature / top_p / repetition_penalty:
                    string values from the Textbox inputs; converted to
                    int/float here (ValueError is caught by the generic
                    handler below and surfaced in the chat).

            Yields:
                The progressively-updated history list for Gradio streaming.
            """
            model = get_model()
            tokenizer = get_tokenizer()
            # Guard clauses: nothing to do, or model/tokenizer not loaded yet.
            if not history:
                yield history
                return
            if model is None or tokenizer is None:
                history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
                yield history
                return
            try:
                inputs = tokenizer.apply_chat_template(
                    history,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(model.device)
                # skip_prompt avoids echoing the input; skip_special_tokens
                # strips EOS/pad markers from the streamed text.
                streamer = TextIteratorStreamer(
                    tokenizer, skip_prompt=True, skip_special_tokens=True
                )
                generation_kwargs = dict(
                    input_ids=inputs,
                    streamer=streamer,
                    max_new_tokens=int(max_new_tokens),
                    temperature=float(temperature),
                    top_p=float(top_p),
                    repetition_penalty=float(repetition_penalty),
                    do_sample=True,
                    # NOTE(review): use_cache=False slows generation
                    # considerably; kept as-is in case it was chosen to limit
                    # memory — confirm before flipping to True.
                    use_cache=False,
                )
                # generate() blocks, so run it in a thread and consume the
                # streamer on this (generator) side.
                thread = Thread(target=model.generate, kwargs=generation_kwargs)
                thread.start()
                history.append({"role": "assistant", "content": ""})
                for new_text in streamer:
                    if new_text:
                        history[-1]["content"] += new_text
                        yield history
                # Ensure the worker thread is reaped before returning.
                thread.join()
            except Exception:
                import traceback

                error_message = f"生成回复时出错:\n{traceback.format_exc()}"
                # Reuse the empty placeholder bubble if we already created one.
                if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
                    history[-1]["content"] = error_message
                else:
                    history.append({"role": "assistant", "content": error_message})
                yield history

        # Submit: record the user turn, then stream the bot turn with the
        # current hyperparameter values.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot,
            [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input],
            chatbot,
        )
        clear.click(lambda: None, None, chatbot, queue=False)
    return demo
if __name__ == "__main__": if __name__ == "__main__":