How to speed up Stable Diffusion inference
Quick answer
Speed up Stable Diffusion inference by using 4-bit quantization with `BitsAndBytesConfig`, enabling FP16 mixed precision, and batching multiple prompts. Use the optimized pipelines from `diffusers` with GPU acceleration, and consider distilled model variants such as SDXL Turbo for faster generation.

Prerequisites
- Python 3.8+
- `pip install "diffusers>=0.20.0" torch torchvision accelerate transformers bitsandbytes` (quote the version specifier so the shell does not treat `>` as a redirect)
- CUDA-enabled GPU with up-to-date drivers
- Basic knowledge of PyTorch and Hugging Face pipelines
Setup
Install the required Python packages for Stable Diffusion inference acceleration. Ensure you have a CUDA-enabled GPU and the latest NVIDIA drivers installed for best performance.
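After installing the packages below, you can sanity-check that the environment meets the version prerequisites. This is a minimal stdlib-only sketch; the helper names `meets_minimum` and `check_package` are hypothetical, not part of any library.

```python
import re
from importlib.metadata import PackageNotFoundError, version

def meets_minimum(installed: str, required: str) -> bool:
    """Compare dotted version strings numerically, e.g. '0.21.4' >= '0.20.0'."""
    # Grab the leading numeric components; this ignores local tags like '+cu121'
    to_tuple = lambda v: tuple(int(p) for p in re.findall(r"\d+", v)[:3])
    return to_tuple(installed) >= to_tuple(required)

def check_package(name: str, minimum: str) -> bool:
    """Return True if `name` is installed at or above `minimum`."""
    try:
        return meets_minimum(version(name), minimum)
    except PackageNotFoundError:
        return False

if __name__ == "__main__":
    for pkg, minimum in [("diffusers", "0.20.0"), ("torch", "2.0.0")]:
        status = "OK" if check_package(pkg, minimum) else "MISSING/OLD"
        print(f"{pkg}>={minimum}: {status}")
```

Running this before the tutorial code catches the most common setup failure (an old `diffusers` without quantization support) early.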
```shell
pip install diffusers torch torchvision accelerate transformers bitsandbytes
```

Step by step
This example demonstrates loading a Stable Diffusion model with 4-bit quantization and FP16 mixed precision, then generating images in batches for faster throughput.
```python
import torch
from diffusers import BitsAndBytesConfig, StableDiffusionPipeline, UNet2DConditionModel

# Configure 4-bit quantization. Note: this is diffusers' own BitsAndBytesConfig
# (not the transformers one); it requires a recent diffusers release and the
# bitsandbytes package.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Quantization is applied per model component, so load the UNet separately
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# Load the rest of the pipeline in fp16 and plug in the quantized UNet
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    unet=unet,
    torch_dtype=torch.float16,
)
# If .to() is rejected for the quantized UNet on your diffusers version,
# use pipe.enable_model_cpu_offload() instead
pipe.to("cuda")
pipe.enable_attention_slicing()  # reduce peak VRAM usage

# Batch prompts for faster inference
prompts = [
    "A futuristic cityscape at sunset",
    "A fantasy forest with glowing plants",
    "A cyberpunk robot portrait",
]
images = pipe(prompts, num_inference_steps=25).images

for i, img in enumerate(images):
    img.save(f"output_{i}.png")
    print(f"Saved output_{i}.png")
```

Output
Saved output_0.png
Saved output_1.png
Saved output_2.png
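To confirm that batching actually improves throughput on your hardware, it helps to time the pipeline call and report images per second. A small sketch; the `timed` helper is a hypothetical name, not a library function, and the pipeline call itself is commented out because it needs a GPU and the downloaded model.

```python
import time
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")

def timed(fn: Callable[[], T]) -> Tuple[T, float]:
    """Run fn once and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start

# Usage against the pipeline above:
# images, seconds = timed(lambda: pipe(prompts, num_inference_steps=25).images)
# print(f"{len(images) / seconds:.2f} images/sec")
```

Comparing the batched call against three single-prompt calls makes the throughput gain concrete.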
Common variations
- Use `pipe.enable_xformers_memory_efficient_attention()` if your GPU supports it for faster attention computation.
- For CPU-only environments, use smaller models or run with `torch_dtype=torch.float32` without quantization.
- Adjust `num_inference_steps` to trade off speed and quality; fewer steps speed up inference.
- Use asynchronous inference with `asyncio` and `accelerate` for parallel requests.
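When a prompt list is too large to fit in one batch, a common pattern is to split it into VRAM-sized chunks so each call still keeps the GPU busy. A minimal sketch; the `chunked` helper is a hypothetical name, not a diffusers API.

```python
from typing import Iterator, List

def chunked(prompts: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield successive batches of at most batch_size prompts."""
    for start in range(0, len(prompts), batch_size):
        yield prompts[start:start + batch_size]

# e.g. run a large prompt list in batches of 4:
# images = []
# for batch in chunked(all_prompts, batch_size=4):
#     images.extend(pipe(batch, num_inference_steps=25).images)
```

Pick the largest `batch_size` that fits in VRAM; past that point, larger batches trade throughput for out-of-memory errors.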
Troubleshooting
- If you encounter out-of-memory errors, enable `pipe.enable_attention_slicing()` or reduce the batch size.
- Ensure your CUDA drivers and PyTorch versions are compatible to avoid runtime errors.
- If quantization causes quality degradation, try 8-bit quantization or disable quantization.
- For slow startup, cache the model locally to avoid repeated downloads.
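For the slow-startup case, one approach is to snapshot the pipeline to disk once with `pipe.save_pretrained(...)` and prefer that local copy on later runs. A sketch assuming that workflow; `resolve_model` is a hypothetical helper (it checks for `model_index.json`, the index file diffusers writes when saving a pipeline).

```python
from pathlib import Path

def resolve_model(local_dir: str, hub_id: str) -> str:
    """Prefer a local pipeline snapshot over the Hub id if one exists."""
    p = Path(local_dir)
    return str(p) if (p / "model_index.json").exists() else hub_id

# First run: download once, then snapshot next to your script:
# pipe.save_pretrained("./sd15-local")
# Later runs load from disk with no network round-trips:
# pipe = StableDiffusionPipeline.from_pretrained(
#     resolve_model("./sd15-local", "runwayml/stable-diffusion-v1-5"),
#     torch_dtype=torch.float16,
# )
```

The Hugging Face cache already avoids re-downloading weights, but an explicit snapshot also skips Hub metadata checks and works fully offline.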
Key takeaways
- Use 4-bit quantization with `BitsAndBytesConfig` to reduce memory and speed up inference.
- Batch multiple prompts to maximize GPU utilization and throughput.
- Enable attention slicing and memory-efficient attention for lower VRAM usage.
- Adjust inference steps to balance speed and image quality.
- Keep CUDA drivers and dependencies up to date to avoid runtime issues.