Fix Llama slow inference
Quick answer
Fix slow inference with Llama by using 4-bit quantization via BitsAndBytesConfig and enabling device_map="auto" to leverage GPU acceleration. Also, batch inputs and use efficient tokenizers to reduce overhead and improve throughput.
Prerequisites
- Python 3.8+
- pip install transformers>=4.30.0
- pip install bitsandbytes
- CUDA-enabled GPU recommended
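A quick back-of-envelope shows why 4-bit quantization matters: model weights shrink to a quarter of their fp16 size, so far less data moves through GPU memory on every decoding step. A minimal sketch (the helper function and the 8B parameter count are illustrative):

```python
def weight_memory_gb(n_params: float, bits_per_weight: int) -> float:
    """Approximate memory needed just for the model weights, in GB."""
    return n_params * bits_per_weight / 8 / 1e9

# An 8B-parameter Llama model:
fp16_gb = weight_memory_gb(8e9, 16)  # 16.0 GB: too large for many consumer GPUs
int4_gb = weight_memory_gb(8e9, 4)   # 4.0 GB: fits easily, and far less data to move
```

Activations, the KV cache, and quantization overhead add to this, so treat these numbers as a lower bound.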
Setup
Install the required packages for optimized Llama inference, including transformers and bitsandbytes for 4-bit quantization support. Ensure you have a CUDA-enabled GPU for best performance.
pip install transformers bitsandbytes
Step by step
Use BitsAndBytesConfig to load the Llama model in 4-bit precision and enable automatic device mapping to utilize GPU efficiently. Batch your inputs to reduce overhead and speed up inference.
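Before the full example, the payoff of batching can be sketched with a toy latency model: each generate() call pays a roughly fixed setup overhead, which gets amortized across the batch. The numbers below are illustrative, not measured:

```python
def per_prompt_latency(overhead_ms: float, per_prompt_ms: float, batch_size: int) -> float:
    """Toy model: fixed per-call overhead amortized across a batch."""
    return (overhead_ms + per_prompt_ms * batch_size) / batch_size

# Illustrative figures: 200 ms fixed overhead, 50 ms of compute per prompt
print(per_prompt_latency(200, 50, 1))  # 250.0 ms per prompt, run one at a time
print(per_prompt_latency(200, 50, 8))  # 75.0 ms per prompt when batched
```

Real GPUs also gain from better utilization at larger batches, so the actual improvement can be larger than this simple amortization suggests.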
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
# Configure 4-bit quantization
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
# Load tokenizer and model with quantization and device map
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map="auto"
)
# Llama tokenizers define no pad token by default; reuse EOS and left-pad for generation
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# Prepare batch inputs
texts = ["Hello, how are you?", "What is the capital of France?"]
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
# Generate outputs
outputs = model.generate(**inputs, max_new_tokens=50)
# Decode and print results
for i, output in enumerate(outputs):
    print(f"Input: {texts[i]}")
    print(f"Output: {tokenizer.decode(output, skip_special_tokens=True)}\n")
Output
Input: Hello, how are you?
Output: Hello, how are you? I am doing well, thank you!
Input: What is the capital of France?
Output: The capital of France is Paris.
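One subtlety of batched generation with decoder-only models like Llama is the padding side: prompts are usually left-padded so that each prompt's last token sits directly before the tokens being generated. A toy, tokenizer-free sketch of the difference (pad_batch is a hypothetical helper, not a transformers API):

```python
def pad_batch(batch, pad_id=0, side="left"):
    """Pad variable-length token-id lists to a common length."""
    width = max(len(seq) for seq in batch)
    padded = []
    for seq in batch:
        pad = [pad_id] * (width - len(seq))
        padded.append(pad + seq if side == "left" else seq + pad)
    return padded

batch = [[5, 6], [7, 8, 9]]
print(pad_batch(batch, side="left"))   # [[0, 5, 6], [7, 8, 9]]
print(pad_batch(batch, side="right"))  # [[5, 6, 0], [7, 8, 9]]
```

With Hugging Face tokenizers, the equivalent switch is tokenizer.padding_side = "left", as used in the example above.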
Common variations
- Use load_in_8bit=True instead of 4-bit if 4-bit quantization causes instability.
- For CPU-only environments, avoid device_map="auto" and use torch_dtype=torch.float16 cautiously.
- Use providers that expose Llama models through OpenAI-compatible APIs (for example, Groq) for faster cloud inference.
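The 8-bit fallback mentioned above only changes the quantization config; the model-loading code stays the same as in the main example (a sketch, assuming the same transformers setup):

```python
from transformers import BitsAndBytesConfig

# 8-bit quantization: roughly twice the weight memory of 4-bit,
# but sometimes more numerically stable
quant_config = BitsAndBytesConfig(load_in_8bit=True)
```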
Troubleshooting
- If inference is still slow, verify your GPU drivers and CUDA installation are up to date.
- Check that bitsandbytes is installed correctly; reinstall if you encounter import errors.
- Reduce batch size if you run out of GPU memory.
- Use torch.cuda.empty_cache() between runs to free GPU memory.
Key takeaways
- Use 4-bit quantization with BitsAndBytesConfig to speed up Llama inference.
- Enable device_map="auto" to leverage GPU acceleration automatically.
- Batch inputs to reduce overhead and improve throughput.
- Ensure CUDA and GPU drivers are properly installed and updated.
- Consider cloud APIs for Llama models if local inference remains slow.