feat(frontend): 增加超参数设置并优化聊天页面布局
- 在聊天页面添加了超参数输入框,包括最大生成长度、温度、Top-p 采样和重复惩罚 - 优化了聊天框的布局,使用 gr.Row() 和 gr.Column() 实现了更合理的界面结构 - 更新了 bot 函数,支持根据用户输入的超参数进行文本生成 - 修复了一些代码格式问题,提高了代码的可读性
This commit is contained in:
parent
61672021ef
commit
83427aaaba
@ -10,91 +10,81 @@ from global_var import get_model, get_tokenizer # 假设这两个函数能正确
|
||||
|
||||
def chat_page():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 聊天")
|
||||
chatbot = gr.Chatbot(type="messages", label="聊天机器人") # 使用 messages 类型,label 可选
|
||||
msg = gr.Textbox(label="输入消息")
|
||||
clear = gr.Button("清除对话")
|
||||
# 聊天框
|
||||
gr.Markdown("## 对话")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
chatbot = gr.Chatbot(type="messages", label="聊天机器人")
|
||||
msg = gr.Textbox(label="输入消息")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
# 新增超参数输入框
|
||||
max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
|
||||
temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
|
||||
top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
|
||||
repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
|
||||
clear = gr.Button("清除对话")
|
||||
|
||||
def user(user_message, history: list):
|
||||
# 清空输入框,并将用户消息添加到 history
|
||||
# 确保 history 是 list of lists 或者 list of tuples (根据 Gradio 版本和 Chatbot type)
|
||||
# 对于 type="messages",期望的格式是 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
return "", history + [{"role": "user", "content": user_message}]
|
||||
|
||||
def bot(history: list):
|
||||
def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
|
||||
model = get_model()
|
||||
tokenizer = get_tokenizer()
|
||||
if not history: # 避免 history 为空时出错
|
||||
if not history:
|
||||
yield history
|
||||
return
|
||||
|
||||
# 检查模型和分词器是否存在
|
||||
if model is None or tokenizer is None:
|
||||
history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
|
||||
yield history
|
||||
return
|
||||
history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
|
||||
yield history
|
||||
return
|
||||
|
||||
try:
|
||||
# --- 关键改动点 ---
|
||||
# 1. 使用完整的 history (或者根据需要裁剪) 来创建输入
|
||||
# apply_chat_template 通常能处理这种 [{"role": ..., "content": ...}] 格式
|
||||
# 确保你的模型和 tokenizer 支持 chat template
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
history,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True, # 对很多指令调优模型是必要的
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device) # 将输入张量移动到模型所在的设备
|
||||
).to(model.device)
|
||||
|
||||
# 2. 使用 TextIteratorStreamer
|
||||
# skip_prompt=True: 不在流中包含原始输入 prompt
|
||||
# skip_special_tokens=True: 不在流中包含特殊 token (如 <eos>)
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
# 3. 将 model.generate 放入单独的线程中运行
|
||||
# 这样 Gradio 的主线程不会被阻塞,可以接收流式输出
|
||||
# 将超参数转换为数值类型
|
||||
generation_kwargs = dict(
|
||||
input_ids=inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=1024,
|
||||
# temperature=1.5, # 1.5 通常太高,容易产生随机无意义内容,建议 0.7-1.0
|
||||
# min_p=0.1, # min_p 不常用,top_p 更常见
|
||||
do_sample=True, # 启用采样,让 temperature/top_p 生效
|
||||
temperature=0.8, # 稍微降低温度
|
||||
top_p=0.95, # 使用 top_p 采样
|
||||
repetition_penalty=1.1, # 轻微惩罚重复
|
||||
use_cache=False # 通常可以开启以加速生成,除非有特殊原因
|
||||
max_new_tokens=int(max_new_tokens),
|
||||
temperature=float(temperature),
|
||||
top_p=float(top_p),
|
||||
repetition_penalty=float(repetition_penalty),
|
||||
do_sample=True,
|
||||
use_cache=False
|
||||
)
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
# 4. 在 history 中添加一个空的 assistant 回复占位符
|
||||
history.append({"role": "assistant", "content": ""})
|
||||
|
||||
# 5. 迭代 streamer,实时更新 history 中最后一条消息的内容
|
||||
for new_text in streamer:
|
||||
if new_text: # 确保不是空字符串
|
||||
if new_text:
|
||||
history[-1]["content"] += new_text
|
||||
yield history # yield 更新后的 history 给 Gradio Chatbot
|
||||
yield history
|
||||
|
||||
except Exception as e:
|
||||
# 异常处理,将错误信息显示在聊天框
|
||||
import traceback
|
||||
error_message = f"生成回复时出错:\n{traceback.format_exc()}"
|
||||
# 如果最后一条消息是助手的空消息,覆盖它;否则追加
|
||||
if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
|
||||
import traceback
|
||||
error_message = f"生成回复时出错:\n{traceback.format_exc()}"
|
||||
if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
|
||||
history[-1]["content"] = error_message
|
||||
else:
|
||||
else:
|
||||
history.append({"role": "assistant", "content": error_message})
|
||||
yield history # 显示错误信息
|
||||
yield history
|
||||
|
||||
|
||||
# .then() 中 bot 函数的输出直接连接到 chatbot 组件
|
||||
# 更新 .then() 调用,将超参数传递给 bot 函数
|
||||
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
||||
bot, chatbot, chatbot
|
||||
bot, [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input], chatbot
|
||||
)
|
||||
|
||||
# 清除按钮点击后,返回一个空列表来清空 chatbot
|
||||
clear.click(lambda: [], None, chatbot, queue=False)
|
||||
|
||||
return demo
|
||||
|
Loading…
x
Reference in New Issue
Block a user