INT8 and FP8 quantization
Why this matters
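LLaMA 70B in 16-bit precision needs about 140 GB of VRAM for the weights alone. Quantizing to INT8 halves that to roughly 70 GB, which fits on a single 80 GB A100 or across two 40 GB A100s, and can run 2-3x faster where the hardware and kernels support 8-bit math. Production deployments live or die on this decision.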
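The memory figure above is simple arithmetic on weight storage: 140 GB corresponds to 16-bit weights at 70 billion parameters. A minimal sketch, assuming weights dominate and ignoring KV cache, activations, and framework overhead; `weight_memory_gb` is a hypothetical helper, and it uses decimal gigabytes (10^9 bytes):

```python
def weight_memory_gb(num_params: int, bits_per_param: int) -> float:
    """Approximate weight storage in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

# Nominal parameter count (assumption); real checkpoints differ slightly.
params_70b = 70_000_000_000

for name, bits in [("FP32", 32), ("FP16/BF16", 16), ("INT8/FP8", 8)]:
    print(f"{name:>9}: {weight_memory_gb(params_70b, bits):6.1f} GB")
```

Treat the result as a lower bound when sizing hardware: real deployments also need room for the KV cache and activations.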
Explanation
What it is: Quantization converts model weights and activations from a higher-precision format (FP32, or more commonly FP16/BF16 for LLMs) to 8-bit integers (INT8) or 8-bit floats (FP8). INT8 uses two's complement signed integers (-128 to 127). FP8 is a floating-point format with 1 sign bit and the remaining 7 bits split between exponent and mantissa: E4M3 (4 exponent bits, 3 mantissa bits) favors precision, and E5M2 (5 exponent bits, 2 mantissa bits) favors range.
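To make the FP8 layout concrete, the sketch below decodes 8-bit patterns by hand. It assumes the E4M3 variant with exponent bias 7 and no infinities (the layout PyTorch exposes as `torch.float8_e4m3fn`); `decode_e4m3` is a hypothetical helper written for illustration, not a library function.

```python
def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 bit pattern: 1 sign, 4 exponent, 3 mantissa bits, bias 7."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return float("nan")  # the only NaN encoding; this variant has no infinities
    if exponent == 0:
        # Subnormal numbers: no implicit leading 1.
        return sign * (mantissa / 8) * 2.0 ** -6
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)

# Every positive finite value, in increasing order (0x7F is NaN, 0x00 is zero).
positives = sorted(decode_e4m3(b) for b in range(0x01, 0x7F))
gaps = [b - a for a, b in zip(positives, positives[1:])]

print("smallest positive: ", positives[0])
print("largest finite:    ", positives[-1])
print("gap just above 1.0:", decode_e4m3(0x39) - decode_e4m3(0x38))
print("gap at the top:    ", gaps[-1])
```

The gap between neighboring values grows with magnitude, whereas INT8 has the same step everywhere. That non-uniform spacing is what lets FP8 cover a wider dynamic range with the same 8 bits.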
How it works mechanically: The quantizer maps the original value range onto the 8-bit range via a scale factor: quantized_value = round(original_value / scale_factor), and the original is recovered approximately as quantized_value * scale_factor. During inference, matrix multiplications run in low precision and the results are dequantized for operations that need more precision. Weights stay quantized; activations may be quantized per-token or per-layer depending on the framework. Both INT8 and FP8 cut weight memory by 75% relative to FP32 (32 bits → 8 bits), or 50% relative to FP16, and improve cache locality, which speeds up matrix multiplications on hardware with specialized 8-bit units (GPUs, NPUs).
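The scale-factor mapping can be sketched in a few lines of plain Python. This is a minimal illustration of symmetric per-tensor ("absmax") quantization, which is one common choice and an assumption here, not necessarily what any particular framework does; `quantize_absmax` and `dequantize` are hypothetical helper names.

```python
def quantize_absmax(values, num_bits=8):
    """Symmetric per-tensor quantization: map [-max|x|, +max|x|] onto signed integers."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for INT8
    absmax = max(abs(v) for v in values)
    scale = absmax / qmax if absmax > 0 else 1.0
    # Round to the nearest integer level and clamp to the representable range.
    quantized = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return quantized, scale

def dequantize(quantized, scale):
    """Recover approximate real values from integer levels."""
    return [q * scale for q in quantized]

weights = [0.02, -0.5, 0.31, 1.27, -1.0]
q, scale = quantize_absmax(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print("levels:", q)
print("max round-trip error:", max_error, "(bounded by scale / 2 =", scale / 2, ")")

# A single outlier inflates the scale and squashes the small values toward zero.
q_outlier, scale_outlier = quantize_absmax(weights + [50.0])
print("levels with outlier:", q_outlier)
```

The outlier case shows why frameworks use finer-grained scales (per-channel, per-token) or keep outlier features in higher precision: one large value otherwise consumes most of the 8-bit range.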
When to use: INT8 is more mature, widely supported in hardware, and a good fit for latency-critical inference (mobile, edge). FP8 is newer, preserves dynamic range better for activations, and suits batch inference, but native FP8 compute needs recent GPUs (NVIDIA Ada or Hopper generation). For LLaMA specifically, post-training quantization (PTQ) via bitsandbytes (INT8) or GPTQ is the common production route; QAT (quantization-aware training) with FP8 is emerging for fine-tuned models.
Analogy
INT8 is like storing book reviews as 0-100 integer scores instead of precise decimals: you lose sub-integer precision but fit the data in a quarter the space and read scores faster. FP8 is like keeping the decimal representation but with fewer digits: 8.7 instead of 8.73521, trading some precision for range and hardware efficiency.
Code
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
# Llama 3.1 is the release with an 8B Instruct checkpoint (gated: requires Hugging Face access).
model_id = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Baseline: full-precision weights at 4 bytes per parameter.
print('Loading FP32 model...')
model_fp32 = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.float32
)
fp32_params = sum(p.numel() for p in model_fp32.parameters())
fp32_memory_gb = fp32_params * 4 / (1024**3)
print(f'FP32 model parameters: {fp32_params:,}, memory: {fp32_memory_gb:.2f} GB')
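
# INT8 via bitsandbytes (LLM.int8()): needs a CUDA GPU and the bitsandbytes package.
print('\nLoading INT8 quantized model with bitsandbytes...')
from transformers import BitsAndBytesConfig
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,             # activation outliers above this stay in 16-bit
    llm_int8_skip_modules=['lm_head']   # keep the output projection unquantized
)
model_int8 = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    device_map='auto'
)
# Idealized estimate: one quarter of FP32. The real footprint is somewhat higher because
# skipped modules stay in higher precision; model_int8.get_memory_footprint() reports
# the measured value in bytes.
int8_memory_gb = fp32_memory_gb * 0.25
print(f'INT8 model memory: ~{int8_memory_gb:.2f} GB (4x reduction)')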

# Time a 100-token generation with each model. Greedy decoding (do_sample=False)
# keeps the two runs comparable; sampling would add run-to-run noise.
print('\nInference speed comparison (100 tokens)...')
prompt = 'What is machine learning?'
inputs = tokenizer(prompt, return_tensors='pt').to(model_fp32.device)
start = time.perf_counter()
with torch.no_grad():
    outputs_fp32 = model_fp32.generate(**inputs, max_new_tokens=100, do_sample=False)
fp32_time = time.perf_counter() - start
fp32_text = tokenizer.decode(outputs_fp32[0], skip_special_tokens=True)
print(f'FP32 generation time: {fp32_time:.3f}s')

inputs_int8 = tokenizer(prompt, return_tensors='pt').to(model_int8.device)
start = time.perf_counter()
with torch.no_grad():
    outputs_int8 = model_int8.generate(**inputs_int8, max_new_tokens=100, do_sample=False)
int8_time = time.perf_counter() - start
int8_text = tokenizer.decode(outputs_int8[0], skip_special_tokens=True)
print(f'INT8 generation time: {int8_time:.3f}s')
print(f'Speedup: {fp32_time / int8_time:.2f}x')

# Compare the first 120 characters of each completion.
print('\nOutput quality check (should be semantically identical):')
print(f'FP32: {fp32_text[:120]}...')
print(f'INT8: {int8_text[:120]}...')

# FP8: native FP8 matmuls need Ada or Hopper GPUs (compute capability 8.9+), not Ampere.
# Passing torch.float8_e4m3fn as torch_dtype only requests FP8 weight storage. Whether it
# loads and runs depends on your transformers/PyTorch versions, so treat this as a probe,
# not a supported FP8 inference path.
print('\nFP8 approach via transformers (Ada/Hopper GPU required for FP8 compute)...')
try:
    model_fp8 = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map='auto',
        torch_dtype=torch.float8_e4m3fn
    )
    print('FP8 model loaded (float8_e4m3fn format)')
except (RuntimeError, ValueError, TypeError) as e:
    print(f'FP8 load note: {e}')

# Save the INT8 model. The quantization config is written into config.json,
# so reloading needs no explicit dtype or quantization arguments.
print('\nQuantization configuration preserved across save/load:')
model_int8.save_pretrained('./llama-int8-quantized')
model_int8_reloaded = AutoModelForCausalLM.from_pretrained(
    './llama-int8-quantized',
    device_map='auto'
)
print('Model saved and reloaded successfully: quantization config persisted')
Output
Loading FP32 model...
FP32 model parameters: 8,030,261,248, memory: 29.92 GB

Loading INT8 quantized model with bitsandbytes...
INT8 model memory: ~7.48 GB (4x reduction)

Inference speed comparison (100 tokens)...
FP32 generation time: 12.847s
INT8 generation time: 4.193s
Speedup: 3.06x

Output quality check (should be semantically identical):
FP32: What is machine learning? Machine learning is a subset of artificial intelligence (AI) that enables systems to learn and...
INT8: What is machine learning? Machine learning is a subset of artificial intelligence (AI) that enables systems to learn and...

FP8 approach via transformers (Ada/Hopper GPU required for FP8 compute)...
FP8 model loaded (float8_e4m3fn format)

Quantization configuration preserved across save/load:
Model saved and reloaded successfully: quantization config persisted
What just happened?
The code loaded LLaMA 3.1 8B in FP32 (about 29.9 GB of weights), then loaded the same model in INT8 using bitsandbytes (about 7.5 GB, estimated as one quarter of the FP32 size). It measured end-to-end inference time for 100-token generation on both: FP32 took 12.8s, INT8 took 4.2s (3.06x faster). The generated text was semantically identical despite the 75% memory reduction. The code also attempted to load the weights in an FP8 dtype and demonstrated that the quantization config survives a model save/load cycle.
Common gotcha
INT8 quantization with bitsandbytes leaves the lm_head layer unquantized by default, because the logits it produces feed directly into token sampling and small errors there can change which token wins. Developers who override the skip list and quantize every layer can get degraded or incoherent output that looks like a model crash. Always check that llm_int8_skip_modules matches your architecture: for LLaMA it's ['lm_head']; other models name their output head differently.
Error recovery
RuntimeError: Could not run 'aten::scaled_dot_product_attention'
OutOfMemoryError during quantization
AttributeError: 'BitsAndBytesConfig' object has no attribute 'load_in_8bit'
Model output is complete nonsense (token IDs all 2, 3, or 4)
Quantized weights saved but model loads as FP32
Experienced dev note
INT8 quantization is a one-way operation for inference: you cannot run ordinary full fine-tuning on the quantized weights. Your options are to (1) fine-tune the original higher-precision model and re-quantize, (2) use QAT (quantization-aware training), which requires converting the model to a quantization-aware variant before training, or (3) freeze the quantized base and train small adapter weights on top of it (LoRA-style). Most production teams skip this distinction and just quantize after training, accepting a 0.5-2% accuracy loss. For your domain-specific LLaMA deployment, quantize immediately and measure accuracy on your test set; only if it dips below an acceptable level should you invest in QAT. Also: FP8 is bleeding-edge, while INT8 via bitsandbytes is battle-tested on thousands of models. Use INT8 unless you have a specific reason (very large batch inference, dynamic range needs) to accept the risk of FP8 immaturity.
Check your understanding
Explain why quantizing the lm_head layer of LLaMA to INT8 breaks text generation, and what would need to happen post-inference if we *did* quantize it. What is the semantic difference between INT8 and FP8 in terms of what information is preserved?
Answer hint
A correct answer mentions that lm_head produces the logits used for softmax and sampling. Quantizing it to the 256 levels of INT8 adds error to every logit, and because next-token candidates are often separated by small margins, that error can change their ranking. If lm_head were quantized, its outputs would need to be dequantized to higher precision (FP16 or FP32) before the softmax. INT8 is fixed-point with uniform spacing between values; FP8 is floating-point with roughly logarithmic spacing, so it covers a wider dynamic range than INT8 (though far narrower than FP32) and is better for values spread across many orders of magnitude.