Compare commits
2 Commits
fb6157af05
...
83427aaaba
Author | SHA1 | Date | |
---|---|---|---|
![]() |
83427aaaba | ||
![]() |
61672021ef |
@ -1,62 +1,92 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from threading import Thread # 需要导入 Thread
|
||||||
|
from transformers import TextIteratorStreamer # 使用 TextIteratorStreamer
|
||||||
|
|
||||||
|
# 假设 global_var.py 在父目录
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from global_var import get_model, get_tokenizer
|
from global_var import get_model, get_tokenizer # 假设这两个函数能正确获取模型和分词器
|
||||||
|
|
||||||
def chat_page():
    """Build and return the Gradio chat page.

    Lays out a streaming chat UI: a Chatbot + message box on the left, and
    generation hyper-parameter inputs (max tokens, temperature, top-p,
    repetition penalty) plus a clear button on the right. Relies on the
    module-level `get_model()` / `get_tokenizer()` from `global_var` to fetch
    the already-loaded model and tokenizer.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio app.
    """
    with gr.Blocks() as demo:
        # Chat area
        gr.Markdown("## 对话")
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(type="messages", label="聊天机器人")
                msg = gr.Textbox(label="输入消息")
            with gr.Column(scale=1):
                # Hyper-parameter inputs; values arrive as strings and are
                # converted to int/float inside bot().
                max_new_tokens_input = gr.Textbox(label="最大生成长度", value="1024")
                temperature_input = gr.Textbox(label="温度 (Temperature)", value="0.8")
                top_p_input = gr.Textbox(label="Top-p 采样", value="0.95")
                repetition_penalty_input = gr.Textbox(label="重复惩罚", value="1.1")
                clear = gr.Button("清除对话")

        def user(user_message, history: list):
            """Append the user's message to history and clear the textbox."""
            return "", history + [{"role": "user", "content": user_message}]

        def bot(history: list, max_new_tokens, temperature, top_p, repetition_penalty):
            """Stream the assistant's reply token-by-token into `history`.

            Runs `model.generate` in a background thread feeding a
            TextIteratorStreamer, and yields the growing history so Gradio
            re-renders after every new chunk.
            """
            model = get_model()
            tokenizer = get_tokenizer()
            if not history:
                yield history
                return
            if model is None or tokenizer is None:
                history.append({"role": "assistant", "content": "错误:模型或分词器未加载。"})
                yield history
                return
            try:
                # Full history goes through the chat template so the model
                # sees prior turns, not just the last message.
                inputs = tokenizer.apply_chat_template(
                    history,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(model.device)
                # FIX: a timeout is required — without one, if generate()
                # dies in the worker thread the end-of-stream marker is never
                # queued and iterating the streamer blocks the UI forever.
                streamer = TextIteratorStreamer(
                    tokenizer,
                    skip_prompt=True,
                    skip_special_tokens=True,
                    timeout=120.0,
                )
                # Convert string hyper-parameters to numeric types.
                generation_kwargs = dict(
                    input_ids=inputs,
                    streamer=streamer,
                    max_new_tokens=int(max_new_tokens),
                    temperature=float(temperature),
                    top_p=float(top_p),
                    repetition_penalty=float(repetition_penalty),
                    do_sample=True,
                    use_cache=False
                )
                thread = Thread(target=model.generate, kwargs=generation_kwargs)
                thread.start()

                # Placeholder assistant message that we grow in place.
                history.append({"role": "assistant", "content": ""})
                for new_text in streamer:
                    if new_text:
                        history[-1]["content"] += new_text
                        yield history
                # Reap the worker thread; the streamer is exhausted only
                # after generate() returns, so this does not block.
                thread.join()
            except Exception:
                import traceback
                error_message = f"生成回复时出错:\n{traceback.format_exc()}"
                # Reuse the empty placeholder bubble if we created one,
                # otherwise append a fresh error message.
                if history and history[-1]["role"] == "assistant" and history[-1]["content"] == "":
                    history[-1]["content"] = error_message
                else:
                    history.append({"role": "assistant", "content": error_message})
                yield history

        # Wire submit: user() records the message, then bot() streams the
        # reply with the current hyper-parameter values.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot, max_new_tokens_input, temperature_input, top_p_input, repetition_penalty_input], chatbot
        )
        # Reset to an empty message list (Chatbot is in "messages" mode).
        clear.click(lambda: [], None, chatbot, queue=False)
    return demo
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user