From d6885fc95cef9684df5da6bff6d697bff34b0e32 Mon Sep 17 00:00:00 2001 From: Anton Bacaj Date: Mon, 26 Jun 2023 15:51:42 +0000 Subject: [PATCH] turn style conversation, enable history --- inference.py | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/inference.py b/inference.py index 2b7d821..dec4a63 100644 --- a/inference.py +++ b/inference.py @@ -1,3 +1,4 @@ +import sys import os from dataclasses import dataclass, asdict from ctransformers import AutoModelForCausalLM, AutoConfig @@ -27,10 +28,6 @@ def format_prompt(system_prompt: str, user_prompt: str): return f"{system_prompt}{user_prompt}{assistant_prompt}" -def format_output(user_prompt: str): - return f"[user]: {user_prompt}\n[assistant]:" - - def generate( llm: AutoModelForCausalLM, generation_config: GenerationConfig, @@ -56,17 +53,7 @@ if __name__ == "__main__": config=config, ) - system_prompt = "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers." - - user_prompts = [ - "What is 2 + 2?", - "What is 12 + 2?", - "What is 5 + 7?", - "What is 3 * 2?", - "What is 4 / 2?", - "Who was the first president of the US?", - "Can humans ever set foot on mars?", - ] + system_prompt = "A conversation between a user and an LLM-based AI assistant named Local Assistant. Local Assistant gives helpful and honest answers." generation_config = GenerationConfig( temperature=0.2, @@ -75,18 +62,19 @@ if __name__ == "__main__": repetition_penalty=1.0, max_new_tokens=512, # adjust as needed seed=42, - reset=True, # reset history (cache) + reset=False, # reset history (cache) stream=True, # streaming per word/token threads=int(os.cpu_count() / 2), # adjust for your CPU stop=["<|im_end|>", "|<"], ) - for user_prompt in user_prompts: - generator = generate(llm, generation_config, system_prompt, user_prompt) - print(format_output(user_prompt), end=" ", flush=True) + user_prefix = "[user]: " + assistant_prefix = f"[assistant]:" + + while True: + user_prompt = input(user_prefix) + generator = generate(llm, generation_config, system_prompt, user_prompt.strip()) + print(assistant_prefix, end=" ", flush=True) for word in generator: print(word, end="", flush=True) - - # print empty line print("") - print(80 * "=")