turn-style conversation, enable history
parent 7ef5a133a4
commit d6885fc95c

1 changed file with 10 additions and 22 deletions
inference.py | 32 lines changed (+10, −22)
@@ -1,3 +1,4 @@
+import sys
 import os
 from dataclasses import dataclass, asdict
 from ctransformers import AutoModelForCausalLM, AutoConfig
@@ -27,10 +28,6 @@ def format_prompt(system_prompt: str, user_prompt: str):
     return f"{system_prompt}{user_prompt}{assistant_prompt}"
 
 
-def format_output(user_prompt: str):
-    return f"[user]: {user_prompt}\n[assistant]:"
-
-
 def generate(
     llm: AutoModelForCausalLM,
     generation_config: GenerationConfig,
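Note: the deleted format_output helper printed the user/assistant labels for the old batch mode; the final hunk replaces it with interactive user_prefix/assistant_prefix strings. The surrounding format_prompt helper is untouched, and only its return line is visible above. A minimal sketch of what it plausibly looks like, assuming a ChatML-style template (suggested by the "<|im_end|>" stop token used later in the diff); everything except the return line is an illustration, not code from this repository:

    def format_prompt(system_prompt: str, user_prompt: str):
        # Assumed ChatML-style wrapping; only the return line appears in
        # the diff, the three lines above it are an illustration.
        system_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        user_prompt = f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
        assistant_prompt = "<|im_start|>assistant\n"
        return f"{system_prompt}{user_prompt}{assistant_prompt}"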
@@ -56,17 +53,7 @@ if __name__ == "__main__":
         config=config,
     )
 
-    system_prompt = "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
+    system_prompt = "A conversation between a user and an LLM-based AI assistant named Local Assistant. Local Assistant gives helpful and honest answers."
 
-    user_prompts = [
-        "What is 2 + 2?",
-        "What is 12 + 2?",
-        "What is 5 + 7?",
-        "What is 3 * 2?",
-        "What is 4 / 2?",
-        "Who was the first president of the US?",
-        "Can humans ever set foot on mars?",
-    ]
-
     generation_config = GenerationConfig(
         temperature=0.2,
@@ -75,18 +62,19 @@ if __name__ == "__main__":
         repetition_penalty=1.0,
         max_new_tokens=512,  # adjust as needed
         seed=42,
-        reset=True,  # reset history (cache)
+        reset=False,  # keep history (cache) between turns
         stream=True,  # streaming per word/token
         threads=int(os.cpu_count() / 2),  # adjust for your CPU
         stop=["<|im_end|>", "|<"],
     )
 
-    for user_prompt in user_prompts:
-        generator = generate(llm, generation_config, system_prompt, user_prompt)
-        print(format_output(user_prompt), end=" ", flush=True)
+    user_prefix = "[user]: "
+    assistant_prefix = "[assistant]:"
+    while True:
+        user_prompt = input(user_prefix)
+        generator = generate(llm, generation_config, system_prompt, user_prompt.strip())
+        print(assistant_prefix, end=" ", flush=True)
         for word in generator:
             print(word, end="", flush=True)
-
-
-        # print empty line
         print("")
+        print(80 * "=")
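Note: the diff calls GenerationConfig and generate without showing their definitions. A minimal sketch of both, inferred from the call sites above; the top_k/top_p fields sit in the unchanged lines between the last two hunks and are assumptions, as is the body of generate:

    from dataclasses import dataclass, asdict

    @dataclass
    class GenerationConfig:
        temperature: float
        top_k: int    # assumed; hidden in the unchanged lines between hunks
        top_p: float  # assumed
        repetition_penalty: float
        max_new_tokens: int
        seed: int
        reset: bool
        stream: bool
        threads: int
        stop: list

    def generate(llm, generation_config, system_prompt, user_prompt):
        # Build the full prompt, then call the model. ctransformers models
        # are callable; with stream=True the call returns a generator that
        # yields text chunks as they are produced.
        prompt = format_prompt(system_prompt, user_prompt)
        return llm(prompt, **asdict(generation_config))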
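The line that actually enables history is reset=False: in ctransformers, reset controls whether the model's cached state is cleared before each generation, so leaving it False lets every call continue from the tokens already evaluated in earlier turns. Roughly (a hedged illustration, not code from this repository):

    llm("Hi, my name is Sam.", reset=False, max_new_tokens=32)
    llm("What is my name?", reset=False, max_new_tokens=32)  # continues from the cached turn above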