diff --git a/README.md b/README.md
index e0489bf..52b1ebb 100644
--- a/README.md
+++ b/README.md
@@ -2,13 +2,11 @@
 
 Run inference on the latest MPT-30B model using your CPU. This inference code uses a [ggml](https://github.com/ggerganov/llama.cpp) quantized model. To run the model we'll use a library called [ctransformers](https://github.com/marella/ctransformers) that has bindings to ggml in python.
 
-I recommend a system with 32GB of ram.
-
 [Inference Demo](https://github.com/abacaj/mpt-30B-inference/assets/7272343/486fc9b1-8216-43cc-93c3-781677235502)
 
 ## Requirements
 
-I recommend you use docker for this model, it will make everything easier for you. Tested on AMD Epyc CPU.
+I recommend using Docker for this model; it will make everything easier for you. You'll want a system with 32GB of RAM. Tested on an AMD Epyc CPU with Python 3.10.
 
 ## Setup
 
diff --git a/download_model.py b/download_model.py
index 7601eb2..24e0af5 100644
--- a/download_model.py
+++ b/download_model.py
@@ -2,11 +2,11 @@ import os
 from huggingface_hub import hf_hub_download
 
 
-def download_mpt_quant(destination_folder):
+def download_mpt_quant(destination_folder: str, repo_id: str, model_filename: str):
     local_path = os.path.relpath(destination_folder)
     return hf_hub_download(
-        repo_id="TheBloke/mpt-30B-chat-GGML",
-        filename="mpt-30b-chat.ggmlv0.q4_1.bin",
+        repo_id=repo_id,
+        filename=model_filename,
         cache_dir=local_path,
     )
 
@@ -14,5 +14,7 @@ def download_mpt_quant(destination_folder):
 
 if __name__ == "__main__":
     """full url: https://huggingface.co/TheBloke/mpt-30B-chat-GGML/blob/main/mpt-30b-chat.ggmlv0.q4_1.bin"""
-    destination_folder = "models"
-    download_mpt_quant(destination_folder)
+    repo_id = "TheBloke/mpt-30B-chat-GGML"
+    model_filename = "mpt-30b-chat.ggmlv0.q4_1.bin"
+    destination_folder = "models"
+    download_mpt_quant(destination_folder, repo_id, model_filename)
diff --git a/inference.py b/inference.py
index 7a285eb..73f84fa 100644
--- a/inference.py
+++ b/inference.py
@@ -1,8 +1,23 @@
 import os
+from dataclasses import dataclass, asdict
 from ctransformers import AutoModelForCausalLM, AutoConfig
 
 
-def format_prompt(system_prompt, user_prompt):
+@dataclass
+class GenerationConfig:
+    temperature: float
+    top_k: int
+    top_p: float
+    repetition_penalty: float
+    max_new_tokens: int
+    seed: int
+    reset: bool
+    stream: bool
+    threads: int
+    stop: list[str]
+
+
+def format_prompt(system_prompt: str, user_prompt: str):
     """format prompt based on: https://huggingface.co/spaces/mosaicml/mpt-30b-chat/blob/main/app.py"""
 
     system_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
@@ -12,11 +27,16 @@ def format_prompt(system_prompt, user_prompt):
     return f"{system_prompt}{user_prompt}{assistant_prompt}"
 
 
-def format_output(user_prompt):
+def format_output(user_prompt: str):
     return f"[user]: {user_prompt}\n[assistant]:"
 
 
-def generate(llm, system_prompt, user_prompt):
+def generate(
+    llm: AutoModelForCausalLM,
+    generation_config: GenerationConfig,
+    system_prompt: str,
+    user_prompt: str,
+):
     """run model inference, will return a Generator if streaming is true"""
 
     return llm(
@@ -24,16 +44,7 @@ def generate(llm, system_prompt, user_prompt):
         format_prompt(
             system_prompt,
             user_prompt,
         ),
-        temperature=0.2,
-        top_k=0,
-        top_p=0.9,
-        repetition_penalty=1.0,
-        max_new_tokens=512,  # adjust as needed
-        seed=42,
-        reset=True,  # reset history (cache)
-        stream=True,  # streaming per word/token
-        threads=int(os.cpu_count() / 2),  # adjust for your CPU
-        stop=["<|im_end|>", "|<"],
+        **asdict(generation_config),
     )
 
@@ -57,8 +68,21 @@ if __name__ == "__main__":
"Can humans ever set foot on mars?", ] + generation_config = GenerationConfig( + temperature=0.2, + top_k=0, + top_p=0.9, + repetition_penalty=1.0, + max_new_tokens=512, # adjust as needed + seed=42, + reset=True, # reset history (cache) + stream=True, # streaming per word/token + threads=int(os.cpu_count() / 2), # adjust for your CPU + stop=["<|im_end|>", "|<"], + ) + for user_prompt in user_prompts: - generator = generate(llm, system_prompt, user_prompt) + generator = generate(llm, generation_config, system_prompt, user_prompt) print(format_output(user_prompt), end=" ", flush=True) for word in generator: print(word, end="", flush=True)