cleanup code

Anton Bacaj 2023-06-26 07:10:40 +00:00
parent dcee968bb4
commit c6e2d7305b
3 changed files with 46 additions and 22 deletions


@@ -2,13 +2,11 @@
 Run inference on the latest MPT-30B model using your CPU. This inference code uses a [ggml](https://github.com/ggerganov/llama.cpp) quantized model. To run the model we'll use a library called [ctransformers](https://github.com/marella/ctransformers) that has bindings to ggml in python.
-I recommend a system with 32GB of ram.
 [Inference Demo](https://github.com/abacaj/mpt-30B-inference/assets/7272343/486fc9b1-8216-43cc-93c3-781677235502)
 ## Requirements
-I recommend you use docker for this model, it will make everything easier for you. Tested on AMD Epyc CPU.
+I recommend you use docker for this model, it will make everything easier for you. I recommend a system with 32GB of ram. Tested on AMD Epyc CPU, Python 3.10.
 ## Setup
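For context on what the README describes, the loading step done by the inference script uses ctransformers' `AutoModelForCausalLM`. A minimal sketch, assuming the quantized `.bin` has already been downloaded; the exact path under `models/`, `model_type="mpt"`, and `context_length` are assumptions for illustration, not taken verbatim from this diff:

```python
import os
from ctransformers import AutoModelForCausalLM

# placeholder path: the exact location under models/ depends on how
# hf_hub_download lays out its cache after the download step below
model_path = os.path.abspath("models/mpt-30b-chat.ggmlv0.q4_1.bin")

llm = AutoModelForCausalLM.from_pretrained(
    model_path,
    model_type="mpt",     # ggml model family
    context_length=8192,  # assumption; adjust for your RAM
)
```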


@@ -2,11 +2,11 @@ import os
 from huggingface_hub import hf_hub_download
-def download_mpt_quant(destination_folder):
+def download_mpt_quant(destination_folder: str, repo_id: str, model_filename: str):
     local_path = os.path.relpath(destination_folder)
     return hf_hub_download(
-        repo_id="TheBloke/mpt-30B-chat-GGML",
-        filename="mpt-30b-chat.ggmlv0.q4_1.bin",
+        repo_id=repo_id,
+        filename=model_filename,
         cache_dir=local_path,
     )
@@ -14,5 +14,7 @@ def download_mpt_quant(destination_folder):
 if __name__ == "__main__":
     """full url: https://huggingface.co/TheBloke/mpt-30B-chat-GGML/blob/main/mpt-30b-chat.ggmlv0.q4_1.bin"""
-    destination_folder = "models"
-    download_mpt_quant(destination_folder)
+    repo_id = "TheBloke/mpt-30B-chat-GGML"
+    model_filename = "mpt-30b-chat.ggmlv0.q4_1.bin"
+    destination_folder = "models"
+    download_mpt_quant(destination_folder, repo_id, model_filename)
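Worth noting about the hunk above: `hf_hub_download` returns the resolved local file path, so a caller could capture it rather than reconstructing the cache layout by hand. A small sketch of that pattern, using the same repo and filename as the diff (the print is just for illustration):

```python
from huggingface_hub import hf_hub_download

# hf_hub_download returns the local path of the downloaded file
model_path = hf_hub_download(
    repo_id="TheBloke/mpt-30B-chat-GGML",
    filename="mpt-30b-chat.ggmlv0.q4_1.bin",
    cache_dir="models",
)
print(f"model downloaded to: {model_path}")
```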


@@ -1,8 +1,23 @@
 import os
+from dataclasses import dataclass, asdict
 from ctransformers import AutoModelForCausalLM, AutoConfig
-def format_prompt(system_prompt, user_prompt):
+@dataclass
+class GenerationConfig:
+    temperature: float
+    top_k: int
+    top_p: float
+    repetition_penalty: float
+    max_new_tokens: int
+    seed: int
+    reset: bool
+    stream: bool
+    threads: int
+    stop: list[str]
+def format_prompt(system_prompt: str, user_prompt: str):
     """format prompt based on: https://huggingface.co/spaces/mosaicml/mpt-30b-chat/blob/main/app.py"""
     system_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
@@ -12,11 +27,16 @@ def format_prompt(system_prompt, user_prompt):
     return f"{system_prompt}{user_prompt}{assistant_prompt}"
-def format_output(user_prompt):
+def format_output(user_prompt: str):
     return f"[user]: {user_prompt}\n[assistant]:"
-def generate(llm, system_prompt, user_prompt):
+def generate(
+    llm: AutoModelForCausalLM,
+    generation_config: GenerationConfig,
+    system_prompt: str,
+    user_prompt: str,
+):
     """run model inference, will return a Generator if streaming is true"""
     return llm(
@@ -24,16 +44,7 @@ def generate(llm, system_prompt, user_prompt):
             system_prompt,
             user_prompt,
         ),
-        temperature=0.2,
-        top_k=0,
-        top_p=0.9,
-        repetition_penalty=1.0,
-        max_new_tokens=512, # adjust as needed
-        seed=42,
-        reset=True, # reset history (cache)
-        stream=True, # streaming per word/token
-        threads=int(os.cpu_count() / 2), # adjust for your CPU
-        stop=["<|im_end|>", "|<"],
+        **asdict(generation_config),
     )
@ -57,8 +68,21 @@ if __name__ == "__main__":
"Can humans ever set foot on mars?",
]
generation_config = GenerationConfig(
temperature=0.2,
top_k=0,
top_p=0.9,
repetition_penalty=1.0,
max_new_tokens=512, # adjust as needed
seed=42,
reset=True, # reset history (cache)
stream=True, # streaming per word/token
threads=int(os.cpu_count() / 2), # adjust for your CPU
stop=["<|im_end|>", "|<"],
)
for user_prompt in user_prompts:
generator = generate(llm, system_prompt, user_prompt)
generator = generate(llm, generation_config, system_prompt, user_prompt)
print(format_output(user_prompt), end=" ", flush=True)
for word in generator:
print(word, end="", flush=True)
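
The refactor above hinges on `**asdict(generation_config)`: `dataclasses.asdict` converts the config object into a plain dict, which `**` then unpacks into keyword arguments for the model call, replacing the previously hard-coded parameters. A small self-contained illustration of that mechanism (the `fake_llm` function is a stand-in, not part of this repo):

```python
from dataclasses import dataclass, asdict

@dataclass
class GenerationConfig:
    temperature: float
    max_new_tokens: int

def fake_llm(prompt: str, **kwargs) -> str:
    # stand-in for the ctransformers model call
    return f"{prompt} -> {kwargs}"

config = GenerationConfig(temperature=0.2, max_new_tokens=512)
# asdict(config) == {"temperature": 0.2, "max_new_tokens": 512}
print(fake_llm("hello", **asdict(config)))
```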