feat(frontend): 初步实现聊天页面的智能回复功能
This commit is contained in:
@@ -2,7 +2,7 @@ import gradio as gr
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
sys.path.append(str(Path(__file__).resolve().parent.parent))
|
||||||
from global_var import model,tokenizer
|
from global_var import get_model, get_tokenizer
|
||||||
|
|
||||||
def chat_page():
|
def chat_page():
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
@@ -17,11 +17,39 @@ def chat_page():
|
|||||||
return "", history + [{"role": "user", "content": user_message}]
|
return "", history + [{"role": "user", "content": user_message}]
|
||||||
|
|
||||||
def bot(history: list):
    """Stream an assistant reply for the latest user message in *history*.

    Args:
        history: Gradio "messages"-style chat history; the last entry is the
            user turn appended by ``user()``.

    Yields:
        The same ``history`` list after each decoded text chunk, so the
        Chatbot component re-renders the growing assistant reply.
    """
    # Local imports: transformers is a heavy dependency of this page only.
    from threading import Thread

    from transformers import TextIteratorStreamer

    model = get_model()
    tokenizer = get_tokenizer()

    # Latest user turn (appended by user() just before this runs).
    user_message = history[-1]["content"]

    # Build a single-turn prompt using the model's chat template.
    messages = [{"role": "user", "content": user_message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")  # NOTE(review): assumes a CUDA device is present — confirm

    # generate() blocks and returns tensors, so iterating over its return
    # value would walk batch rows, not tokens.  TextIteratorStreamer lets us
    # consume decoded text chunks while generate() runs in a worker thread.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    worker = Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=inputs,
            streamer=streamer,
            max_new_tokens=1024,
            use_cache=False,
            temperature=1.5,
            min_p=0.1,
        ),
    )
    worker.start()

    # Append the assistant turn ONCE, then grow its content in place.
    # (Appending inside the loop would create one message per chunk.)
    history.append({"role": "assistant", "content": ""})
    for chunk in streamer:
        history[-1]["content"] += chunk
        yield history
    worker.join()
|
||||||
|
|
||||||
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
||||||
@@ -32,4 +60,8 @@ def chat_page():
|
|||||||
return demo
|
return demo
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Imported here so importing this module for chat_page() alone does not
    # pull in the model-management page.
    from model_manage_page import model_manage_page

    # Mount both pages into one tabbed app.
    pages = [model_manage_page(), chat_page()]
    demo = gr.TabbedInterface(pages, ["模型管理", "聊天"])
    demo.queue()
    demo.launch()
|
Reference in New Issue
Block a user