This repository provides an easy and fast way to run Gemma-2 locally, directly from your CLI or via a Python library. It is built on top of the 🤗 Transformers and bitsandbytes libraries.
It can be configured to give fully equivalent results to the original implementation, or reduce memory requirements down to just the largest layer in the model!
Important
There is a new "speed" preset for running local-gemma on CUDA ⚡️ It makes use of torch compile for up to 6x faster generation. Set --preset="speed" when using the CLI, or pass preset="speed" to from_pretrained when using the Python API.
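For example, here is a minimal sketch of enabling the preset through the Python API, assuming a CUDA machine and the google/gemma-2-9b-it checkpoint (note that the first generation call is typically slower while torch compile warms up):
from local_gemma import LocalGemma2ForCausalLM
model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-9b-it", preset="speed")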
There are two installation flavors of local-gemma, which you can select depending on your use case:
pipx - Ideal for CLI
First, follow the installation steps here to install pipx in your environment.
Then, run the command below that matches your machine (CUDA for NVIDIA GPUs, MPS for Apple Silicon, or CPU-only):
pipx install local-gemma"[cuda]"
pipx install local-gemma"[mps]"
pipx install local-gemma"[cpu]"
pip - Ideal for Python (CLI + API)
Local Gemma-2 can be installed as a hardware-specific Python package through pip. The only requirement is a Python installation, details for which can be found here. You can check that you have Python installed locally by running:
python3 --version
Next, create and activate a virtual environment in which to install Local Gemma-2:
python3 -m venv gemma-venv
source gemma-venv/bin/activate
Then install the package that matches your hardware:
pip install local-gemma"[cuda]"
pip install local-gemma"[mps]"
pip install local-gemma"[cpu]"
You can chat with Gemma-2 through an interactive session by calling:
local-gemma
Tip
Local Gemma will check for a Hugging Face "read" token in order to download the model. You can follow this guide to create a token, and pass it when prompted to log in. If you're new to Hugging Face and have never used a Gemma model before, you'll also need to accept the terms at the top of this page.
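If you prefer, and assuming the huggingface_hub CLI is available in your environment, you can log in ahead of time so the token is cached locally:
huggingface-cli login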
Alternatively, you can request a single output by passing a prompt, such as:
local-gemma "What is the capital of France?"
By default, this loads the Gemma-2 9b it model. To load the 2b it or 27b it models, you can set the --model argument accordingly:
local-gemma --model 2b
Local Gemma-2 will automatically find the most performant preset for your hardware, trading off speed and memory. For more control over generation speed and memory usage, set the --preset argument to one of four available options:
- exact: match the original results by maximizing accuracy
- speed: maximize throughput through torch compile (CUDA only)
- memory: reduce memory through 4-bit quantization
- memory_extreme: minimize memory through 4-bit quantization and CPU offload
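For example, to run with the smallest possible memory footprint:
local-gemma --preset memory_extreme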
You can also control the style of the generated text through the --mode flag, which is one of "chat", "factual" or "creative":
local-gemma --model 9b --preset memory --mode factual
Finally, you can also pipe in other commands, which will be appended to the prompt after a \n separator:
ls -la | local-gemma "Describe my files"
To see all available decoding options, call local-gemma -h.
Note
The pipx installation method creates its own Python environment, so you will need to use the pip installation method to use this library in a Python script.
Local Gemma-2 can be run locally through a Python interpreter using the familiar Transformers API. To enable a preset, import the model class from local_gemma and pass the preset argument to from_pretrained. For example, the following code snippet loads the Gemma-2 9b model with the "memory" preset:
from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer
model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-9b", preset="memory")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model_inputs = tokenizer("The cat sat on the mat", return_attention_mask=True, return_tensors="pt")
generated_ids = model.generate(**model_inputs.to(model.device))
decoded_text = tokenizer.batch_decode(generated_ids)
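batch_decode returns a list of strings (one per sequence in the batch), so, for example, print(decoded_text[0]) displays the generated continuation, which includes the original prompt.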
When using an instruction-tuned model (suffixed by -it) for conversational use, prepare the inputs with a chat template. The following example loads the Gemma-2 2b it model using the "auto" preset, which automatically determines the best preset for the device:
from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer
model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-2b-it", preset="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
messages = [
{"role": "user", "content": "What is your favourite condiment?"},
{"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
{"role": "user", "content": "Do you have mayonnaise recipes?"}
]
model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True)
generated_ids = model.generate(**model_inputs.to(model.device), max_new_tokens=1024, do_sample=True)
decoded_text = tokenizer.batch_decode(generated_ids)
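Since the generated ids also contain the prompt, a minimal sketch for printing only the model's new reply (reusing the variables from the snippet above) is:
input_length = model_inputs["input_ids"].shape[-1]
print(tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0])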
Local Gemma-2 provides four presets that trade off accuracy, speed and memory. The following results highlight this trade-off using Gemma-2 9b with batch size 1 on an 80GB A100 GPU:
Mode | Performance* | Inference Speed (tok/s) | Memory (GB) |
---|---|---|---|
exact | 73.0 | 17.2 | 18.3 |
speed (CUDA-only) | 73.0 | 62.0 | 19.0 |
memory | 72.1 | 13.8 | 7.3 |
memory_extreme | 72.1 | 13.8 | 7.3 |
While an 80GB A100 places the full model on the device, only 3.7GB is required with the memory_extreme preset. See the Preset Details section below for more information.
*Zero-shot results averaged over Wino, ARC Easy, ARC Challenge, PIQA, HellaSwag, MMLU and OpenBook QA.
Mode | 2b Min Memory (GB) | 9b Min Memory (GB) | 27b Min Memory (GB) | Weights dtype | CPU Offload |
---|---|---|---|---|---|
exact | 5.3 | 18.3 | 54.6 | bf16 | no |
speed (CUDA-only) | 5.4 | 19.0 | 55.8 | bf16 | no |
memory | 3.7 | 7.3 | 17.0 | int4 | no |
memory_extreme | 1.8 | 3.7 | 4.7 | int4 | yes |
memory_extreme implements CPU offloading through 🤗 Accelerate, reducing memory requirements down to the largest layer in the model (which in this case is the LM head).
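As an illustrative sketch, assuming the 27b instruction-tuned checkpoint and a machine whose GPU cannot hold the full model, the preset is enabled in the same way as the others:
from local_gemma import LocalGemma2ForCausalLM
model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-27b-it", preset="memory_extreme")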
Local Gemma-2 is a convenient wrapper around several open-source projects, which we thank explicitly below:
- Transformers for the PyTorch Gemma-2 implementation. Particularly Arthur Zucker for adding the model and the logit soft-capping fixes.
- bitsandbytes for the 4-bit optimization on CUDA.
- quanto for the 4-bit optimization on MPS + CPU.
- Accelerate for the large model loading utilities.
And last but not least, thank you to Google for the pre-trained Gemma-2 checkpoints, all of which you can find on the Hugging Face Hub.