cleanup code

parent dcee968bb4
commit c6e2d7305b
3 changed files with 46 additions and 22 deletions

@@ -2,13 +2,11 @@
 Run inference on the latest MPT-30B model using your CPU. This inference code uses a [ggml](https://github.com/ggerganov/llama.cpp) quantized model. To run the model we'll use a library called [ctransformers](https://github.com/marella/ctransformers) that has bindings to ggml in python.

-I recommend a system with 32GB of ram.
-
 [Inference Demo](https://github.com/abacaj/mpt-30B-inference/assets/7272343/486fc9b1-8216-43cc-93c3-781677235502)

 ## Requirements

-I recommend you use docker for this model, it will make everything easier for you. Tested on AMD Epyc CPU.
+I recommend you use docker for this model, it will make everything easier for you. I recommend a system with 32GB of ram. Tested on AMD Epyc CPU, Python 3.10.

 ## Setup

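As context for the README text above: the repository runs the ggml-quantized model through ctransformers on CPU. A minimal loading sketch under assumptions; the local model path and `model_type="mpt"` are illustrative, not taken from this commit:

```python
from ctransformers import AutoModelForCausalLM

# Assumed location: wherever the download script placed the quantized file.
MODEL_PATH = "models/mpt-30b-chat.ggmlv0.q4_1.bin"

# ctransformers loads ggml weights directly on CPU; model_type tells it
# which architecture (here MPT) the ggml file contains.
llm = AutoModelForCausalLM.from_pretrained(MODEL_PATH, model_type="mpt")

# The loaded model is callable and returns generated text
# (or a generator when stream=True, as inference.py below uses).
print(llm("What is quantization?", max_new_tokens=64))
```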
@@ -2,11 +2,11 @@ import os
 from huggingface_hub import hf_hub_download


-def download_mpt_quant(destination_folder):
+def download_mpt_quant(destination_folder: str, repo_id: str, model_filename: str):
     local_path = os.path.relpath(destination_folder)
     return hf_hub_download(
-        repo_id="TheBloke/mpt-30B-chat-GGML",
-        filename="mpt-30b-chat.ggmlv0.q4_1.bin",
+        repo_id=repo_id,
+        filename=model_filename,
         cache_dir=local_path,
     )

@@ -14,5 +14,7 @@ def download_mpt_quant(destination_folder):
 if __name__ == "__main__":
     """full url: https://huggingface.co/TheBloke/mpt-30B-chat-GGML/blob/main/mpt-30b-chat.ggmlv0.q4_1.bin"""

-    destination_folder = "models"
-    download_mpt_quant(destination_folder)
+    repo_id = "TheBloke/mpt-30B-chat-GGML"
+    model_filename = "mpt-30b-chat.ggmlv0.q4_1.bin"
+    destination_folder = "models"
+    download_mpt_quant(destination_folder, repo_id, model_filename)

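With the parameterized signature above, the downloader can be reused for other repos or quantization files. A small usage sketch; the values are the ones already in the script's `__main__` block:

```python
# The function returns the local path reported by hf_hub_download,
# which the inference script can then point ctransformers at.
local_model_path = download_mpt_quant(
    destination_folder="models",
    repo_id="TheBloke/mpt-30B-chat-GGML",
    model_filename="mpt-30b-chat.ggmlv0.q4_1.bin",
)
print(local_model_path)
```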
inference.py (52 changes)

@@ -1,8 +1,23 @@
 import os
+from dataclasses import dataclass, asdict
 from ctransformers import AutoModelForCausalLM, AutoConfig


-def format_prompt(system_prompt, user_prompt):
+@dataclass
+class GenerationConfig:
+    temperature: float
+    top_k: int
+    top_p: float
+    repetition_penalty: float
+    max_new_tokens: int
+    seed: int
+    reset: bool
+    stream: bool
+    threads: int
+    stop: list[str]
+
+
+def format_prompt(system_prompt: str, user_prompt: str):
     """format prompt based on: https://huggingface.co/spaces/mosaicml/mpt-30b-chat/blob/main/app.py"""

     system_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"

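Only the system turn of the prompt template is visible in this hunk. A sketch of what `format_prompt` plausibly returns, assuming the user and assistant turns follow the same `<|im_start|>...<|im_end|>` (ChatML-style) pattern as the system turn shown above:

```python
# Hypothetical output of format_prompt("You are a helpful assistant", "How are you?"),
# assuming the user/assistant turns mirror the system turn in the diff.
prompt = (
    "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n"
    "<|im_start|>user\nHow are you?<|im_end|>\n"
    "<|im_start|>assistant\n"
)
```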
@@ -12,11 +27,16 @@ def format_prompt(system_prompt, user_prompt):
     return f"{system_prompt}{user_prompt}{assistant_prompt}"


-def format_output(user_prompt):
+def format_output(user_prompt: str):
     return f"[user]: {user_prompt}\n[assistant]:"


-def generate(llm, system_prompt, user_prompt):
+def generate(
+    llm: AutoModelForCausalLM,
+    generation_config: GenerationConfig,
+    system_prompt: str,
+    user_prompt: str,
+):
     """run model inference, will return a Generator if streaming is true"""

     return llm(

@@ -24,16 +44,7 @@ def generate(llm, system_prompt, user_prompt):
             system_prompt,
             user_prompt,
         ),
-        temperature=0.2,
-        top_k=0,
-        top_p=0.9,
-        repetition_penalty=1.0,
-        max_new_tokens=512,  # adjust as needed
-        seed=42,
-        reset=True,  # reset history (cache)
-        stream=True,  # streaming per word/token
-        threads=int(os.cpu_count() / 2),  # adjust for your CPU
-        stop=["<|im_end|>", "|<"],
+        **asdict(generation_config),
     )

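The `**asdict(generation_config)` line relies on standard `dataclasses` behavior: `asdict` converts the dataclass to a plain dict, which is then unpacked as keyword arguments into the model call. A self-contained sketch with illustrative names (not from this commit):

```python
from dataclasses import dataclass, asdict


@dataclass
class SamplingConfig:  # hypothetical stand-in for GenerationConfig above
    temperature: float
    top_k: int


cfg = SamplingConfig(temperature=0.2, top_k=0)
assert asdict(cfg) == {"temperature": 0.2, "top_k": 0}


def fake_llm(prompt, temperature, top_k):  # stand-in for the ctransformers call
    return f"{prompt} (T={temperature}, k={top_k})"


# Equivalent to fake_llm("hello", temperature=0.2, top_k=0)
print(fake_llm("hello", **asdict(cfg)))
```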
@@ -57,8 +68,21 @@ if __name__ == "__main__":
         "Can humans ever set foot on mars?",
     ]

+    generation_config = GenerationConfig(
+        temperature=0.2,
+        top_k=0,
+        top_p=0.9,
+        repetition_penalty=1.0,
+        max_new_tokens=512,  # adjust as needed
+        seed=42,
+        reset=True,  # reset history (cache)
+        stream=True,  # streaming per word/token
+        threads=int(os.cpu_count() / 2),  # adjust for your CPU
+        stop=["<|im_end|>", "|<"],
+    )
+
     for user_prompt in user_prompts:
-        generator = generate(llm, system_prompt, user_prompt)
+        generator = generate(llm, generation_config, system_prompt, user_prompt)
         print(format_output(user_prompt), end=" ", flush=True)
         for word in generator:
             print(word, end="", flush=True)