Latency SLA for transformer models
Why this matters
In production, transformer models don't have fixed latency: generation time scales with output length, batch size, and hardware. Without SLA tracking, violations go unnoticed until customers complain. You need to profile real inference patterns, establish baseline latencies, and implement monitoring that catches violations before they cascade.
Explanation
What it is: A latency SLA (Service Level Agreement) for transformers is a measurable contract that specifies the maximum allowable time to generate a response under given input constraints. Unlike traditional ML models with fixed input/output shapes, generative transformers have variable latency because generation is autoregressive: tokens are produced one at a time, so you don't know the output length until generation stops.
How it works mechanically: You profile inference by (1) measuring time-to-first-token (TTFT) and time-per-token (TPT) separately, (2) simulating realistic workloads with varying sequence lengths and batch sizes, (3) establishing percentile latencies (p50, p95, p99), and (4) instrumenting production code to track violations. The key insight: total_latency = TTFT + (num_tokens_generated - 1) × TPT, since TTFT already includes the first token. TTFT is dominated by the prefill forward pass over the whole prompt (plus any time the request spends queued); TPT is dominated by each single-token decode step, which typically has to read the model weights and the growing KV cache from memory. You set SLAs on both, not just on total time.
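The decomposition above can be written out directly. This is a minimal sketch in plain Python, assuming TTFT already includes the first token so only the remaining N - 1 tokens are charged at TPT (the code in this section uses the same convention). `predicted_total_ms` is an illustrative name, and the latency numbers are made up for the example.

```python
def predicted_total_ms(ttft_ms: float, tpt_ms: float, num_tokens: int) -> float:
    """Total latency = TTFT + (num_tokens - 1) * TPT.

    TTFT covers the prompt pass and the first token, so only the
    remaining num_tokens - 1 tokens are charged at TPT.
    """
    if num_tokens < 1:
        raise ValueError("num_tokens must be at least 1")
    return ttft_ms + (num_tokens - 1) * tpt_ms


# Same model and hardware (TTFT 200 ms, TPT 20 ms), two response lengths
for n in (5, 100):
    total = predicted_total_ms(200.0, 20.0, n)
    ttft_share = 200.0 / total * 100
    print(f"{n:>3} tokens: {total:.0f} ms total, TTFT is {ttft_share:.0f}% of it")
```

With these numbers the 5-token reply takes 280 ms, mostly TTFT, while the 100-token reply takes 2180 ms, mostly TPT, which is why each component gets its own threshold.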
When to use it: Whenever you're deploying transformer inference for real-time applications (chat, summarization, code completion). This becomes critical when you have heterogeneous hardware, batching strategies, or dynamic model serving (multiple models competing for GPU memory).
Analogy
Think of a restaurant kitchen. The first dish out (TTFT) involves setup time: gathering ingredients, heating the wok. Each subsequent dish (TPT, one per token) comes faster because the kitchen is already warm. Your SLA isn't just 'average dinner time'; it's '30 seconds for the first plate, then 5 seconds per additional plate.' If you ignore the difference, you'll promise latencies you can't meet.
Code
```python
import time

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import BaseStreamer

model_name = 'gpt2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# float16 is only dependable on GPU; fall back to float32 on CPU
dtype = torch.float16 if device == 'cuda' else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
model.eval()
# GPT-2 ships without a pad token; reuse EOS so generate() can pad
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


class TokenTimer(BaseStreamer):
    """Streamer that records when each generated token becomes available.

    generate() calls put() once with the prompt ids, then once per new token.
    """

    def __init__(self):
        self.token_times = []
        self._prompt_seen = False

    def put(self, value):
        if not self._prompt_seen:
            self._prompt_seen = True  # first call carries the prompt, not a new token
            return
        # generate() passes CPU copies, so this token has already been computed
        self.token_times.append(time.perf_counter())

    def end(self):
        pass


def timed_generate(prompt, max_new_tokens):
    """Run one greedy request; return (ttft_sec, tpt_sec, num_generated)."""
    timer = TokenTimer()
    start = time.perf_counter()
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            streamer=timer,
        )
    times = timer.token_times
    ttft = times[0] - start
    # Average gap between consecutive tokens after the first one
    tpt = (times[-1] - times[0]) / (len(times) - 1) if len(times) > 1 else 0.0
    return ttft, tpt, len(times)


def measure_inference_latency(prompt, max_new_tokens=50, num_runs=3):
    """
    Measure TTFT and TPT over num_runs requests.
    Returns mean and p95 latencies in milliseconds.
    """
    ttft_times = []
    tpt_times = []
    for _ in range(num_runs):
        ttft, tpt, num_generated = timed_generate(prompt, max_new_tokens)
        ttft_times.append(ttft)
        if num_generated > 1:
            tpt_times.append(tpt)
    return {
        'TTFT_ms': np.mean(ttft_times) * 1000,
        'TPT_ms': np.mean(tpt_times) * 1000 if tpt_times else 0.0,
        'TTFT_p95_ms': np.percentile(ttft_times, 95) * 1000,
        'TPT_p95_ms': np.percentile(tpt_times, 95) * 1000 if tpt_times else 0.0,
    }


class LatencySLAMonitor:
    """
    Track per-request inference latency against SLA thresholds.
    """

    def __init__(self, ttft_sla_ms, tpt_sla_ms):
        self.ttft_sla = ttft_sla_ms / 1000
        self.tpt_sla = tpt_sla_ms / 1000
        self.violations = []
        self.latencies = []

    def record(self, ttft_sec, num_tokens_generated, tpt_sec):
        """
        Record a single inference and check it against both SLAs.
        """
        violation = {}
        if ttft_sec > self.ttft_sla:
            violation['ttft_violation'] = {
                'measured_ms': ttft_sec * 1000,
                'sla_ms': self.ttft_sla * 1000,
            }
        # TPT applies to the tokens after the first: budget = (n - 1) * tpt_sla
        decode_tokens = max(num_tokens_generated - 1, 0)
        if decode_tokens and tpt_sec > self.tpt_sla:
            violation['tpt_violation'] = {
                'measured_ms': tpt_sec * decode_tokens * 1000,
                'sla_ms': self.tpt_sla * decode_tokens * 1000,
                'num_tokens': num_tokens_generated,
            }
        if violation:
            self.violations.append(violation)
        self.latencies.append({
            'ttft_ms': ttft_sec * 1000,
            'tpt_ms': tpt_sec * 1000,
            'num_tokens': num_tokens_generated,
        })

    def report(self):
        """
        Generate SLA report: violation rate plus p50/p95/p99 latencies (ms).
        """
        if not self.latencies:
            return {'total_requests': 0, 'violations': 0}
        summary = {
            'total_requests': len(self.latencies),
            'violations': len(self.violations),
            'violation_rate_percent': len(self.violations) / len(self.latencies) * 100,
        }
        for metric in ('ttft', 'tpt'):
            vals = [r[f'{metric}_ms'] for r in self.latencies]
            for p in (50, 95, 99):
                summary[f'{metric}_p{p}_ms'] = float(np.percentile(vals, p))
        summary['sample_violations'] = self.violations[:3]
        return summary


prompt = 'The future of AI is'
# Warm-up: pay one-time costs (CUDA setup, first allocations) before profiling
timed_generate(prompt, max_new_tokens=5)

latencies = measure_inference_latency(prompt, max_new_tokens=30, num_runs=3)
print('Baseline Latencies (3 runs average):')
for key, val in latencies.items():
    print(f'  {key}: {val:.2f}')

monitor = LatencySLAMonitor(ttft_sla_ms=500, tpt_sla_ms=25)
for _ in range(5):
    # Record what each request actually measured, not the baseline averages
    ttft, tpt, num_generated = timed_generate(prompt, max_new_tokens=20)
    monitor.record(ttft, num_generated, tpt)

report = monitor.report()
print('\nSLA Monitor Report:')
for key, val in report.items():
    if isinstance(val, float):
        print(f'  {key}: {val:.2f}')
    else:
        print(f'  {key}: {val}')
```

Example output (values vary with hardware):

```text
Baseline Latencies (3 runs average):
  TTFT_ms: 450.32
  TPT_ms: 22.15
  TTFT_p95_ms: 468.76
  TPT_p95_ms: 24.89

SLA Monitor Report:
  total_requests: 5
  violations: 0
  violation_rate_percent: 0.0
  ttft_p50_ms: 450.32
  ttft_p95_ms: 450.32
  ttft_p99_ms: 450.32
  tpt_p50_ms: 22.15
  tpt_p95_ms: 22.15
  tpt_p99_ms: 22.15
  sample_violations: []
```
What just happened?
The code profiled inference latency by separating time-to-first-token (TTFT) from time-per-token (TPT). It measured 3 inference runs to establish baseline latencies, then passed 5 requests through an SLA monitor that checks each one against a 500 ms TTFT threshold and a 25 ms TPT threshold. The monitor accumulated latency percentiles (p50, p95, p99) and the violation rate. In this run, every request met the SLA because GPT-2 is small; with larger models or stricter thresholds, violations would appear.
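In production, responses are usually streamed, which makes per-request TTFT and TPT cheap to capture: timestamp every chunk as it arrives. Below is a minimal, framework-agnostic sketch, assuming the server exposes generated output as a Python iterable (a streaming iterator such as transformers' `TextIteratorStreamer` yields decoded text chunks, which may group several tokens, in which case TPT becomes time per chunk). `TimedStream` and its injectable `clock` parameter are hypothetical names introduced here for illustration.

```python
import time
from typing import Callable, Iterable, Iterator, List


class TimedStream:
    """Wrap a stream of tokens (or text chunks) and timestamp each arrival.

    Create it when the request is submitted: that moment counts as the start.
    """

    def __init__(self, source: Iterable, clock: Callable[[], float] = time.perf_counter):
        self._source = source
        self._clock = clock
        self.start = clock()
        self.arrivals: List[float] = []

    def __iter__(self) -> Iterator:
        for item in self._source:
            self.arrivals.append(self._clock())  # item has just become available
            yield item

    def ttft(self) -> float:
        """Time from submission to the first item."""
        if not self.arrivals:
            raise ValueError("stream produced no items")
        return self.arrivals[0] - self.start

    def tpt(self) -> float:
        """Mean gap between items after the first; 0.0 for a single item."""
        n = len(self.arrivals)
        return (self.arrivals[-1] - self.arrivals[0]) / (n - 1) if n > 1 else 0.0


# Demo with a fake clock (milliseconds) and a fake token list instead of a model
ticks = iter([0.0, 300.0, 320.0, 340.0, 360.0])
stream = TimedStream(["The", " future", " is", " bright"], clock=lambda: next(ticks))
tokens = list(stream)
print(stream.ttft(), stream.tpt())
```

With the default `time.perf_counter` clock the results are in seconds, matching what `LatencySLAMonitor.record` expects. Keep per-item work in the consuming loop light; otherwise the measured gaps include your own processing time.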
Common gotcha
Most developers measure total latency (start to finish) and set a single SLA threshold. This fails because a 100-token generation and a 5-token generation from the same model will have wildly different latencies. You must separate TTFT (dominated by processing the prompt, paid once per request) from TPT (the per-token decode cost, which adds up linearly with output length). If you set a 2-second SLA and generate 100 tokens at 20 ms/token, you'll exceed it even though nothing broke: the math was wrong from the start.
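To see the failure concretely, here is a small sketch with made-up request numbers. `single_threshold_flags`, `split_sla_flags`, and `length_aware_budget_ms` are illustrative helpers, not a library API; each request is assumed to be summarized as (TTFT ms, TPT ms, tokens generated).

```python
from typing import List, Tuple

# (ttft_ms, tpt_ms, num_tokens) for each request
Request = Tuple[float, float, int]


def single_threshold_flags(reqs: List[Request], total_sla_ms: float) -> List[int]:
    """Indices of requests whose end-to-end time exceeds one fixed limit."""
    return [i for i, (t, p, n) in enumerate(reqs) if t + (n - 1) * p > total_sla_ms]


def split_sla_flags(reqs: List[Request], ttft_sla_ms: float, tpt_sla_ms: float) -> List[int]:
    """Indices of requests that break the TTFT SLA or the per-token SLA."""
    return [i for i, (t, p, n) in enumerate(reqs) if t > ttft_sla_ms or (n > 1 and p > tpt_sla_ms)]


def length_aware_budget_ms(ttft_sla_ms: float, tpt_sla_ms: float, num_tokens: int) -> float:
    """End-to-end limit implied by the two SLAs for a reply of num_tokens."""
    return ttft_sla_ms + (num_tokens - 1) * tpt_sla_ms


requests = [
    (150.0, 20.0, 5),    # short reply, healthy per-token speed
    (150.0, 20.0, 100),  # long reply, same healthy speed
    (150.0, 40.0, 30),   # slow decoding: a real regression
]
print(single_threshold_flags(requests, total_sla_ms=2000.0))
print(split_sla_flags(requests, ttft_sla_ms=500.0, tpt_sla_ms=25.0))
print(length_aware_budget_ms(500.0, 25.0, 100))
```

The fixed 2-second limit flags the healthy 100-token reply (150 + 99 × 20 = 2130 ms) and misses the request decoding at 40 ms/token (150 + 29 × 40 = 1310 ms); the split SLA does the opposite. If you still need an end-to-end number, derive it per response length, as `length_aware_budget_ms` does (2975 ms for 100 tokens here).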
Error recovery
RuntimeError: CUDA out of memory
AttributeError: 'NoneType' object has no attribute 'to'
AssertionError in generate()
Experienced dev note
The most costly mistake: measuring SLA with a cold model on the first run. The first inference pays one-time costs such as CUDA context setup, initial memory allocation, and kernel loading or compilation. Always warm up the model with a throwaway inference before profiling. Second gotcha: p95 and p99 latencies matter more than averages in production, because a customer experiences the slow request, not the average. Monitor percentiles aggressively, especially p99. Third: TTFT grows with input length (longer prompts take longer to process), so profile your SLA at realistic prompt lengths, at least your median and your long tail, not a 1-word prompt.
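Both warm-up and percentile reporting take only a few lines. This is a minimal stdlib sketch, assuming `fn` wraps one complete request; `profile` and `nearest_rank` are hypothetical helper names. Nearest-rank is only one percentile definition: NumPy's `np.percentile` interpolates linearly by default, so on small samples the two can differ slightly.

```python
import math
import time
from typing import Callable, Dict, List


def nearest_rank(samples: List[float], pct: int) -> float:
    """Nearest-rank percentile: smallest sample with at least pct% of samples at or below it."""
    if not samples:
        raise ValueError("no samples")
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]


def profile(fn: Callable[[], object], warmup: int = 2, runs: int = 20) -> Dict[str, float]:
    """Run fn `warmup` times untimed, then time `runs` calls; return p50/p95/p99 in ms."""
    for _ in range(warmup):
        fn()  # throwaway calls absorb one-time setup costs
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000)
    return {f"p{p}_ms": nearest_rank(samples, p) for p in (50, 95, 99)}


# Why averages mislead: 19 fast requests and one slow one (made-up values, ms)
latencies_ms = [20.0] * 19 + [400.0]
mean_ms = sum(latencies_ms) / len(latencies_ms)
print(mean_ms, nearest_rank(latencies_ms, 50), nearest_rank(latencies_ms, 99))
```

Here the mean (39 ms) matches no actual request, while p99 surfaces the 400 ms request a customer really waited for. In practice `fn` would wrap one full generation call, run with your real prompt-length mix.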
Check your understanding
Why does setting a single SLA threshold like 'responses must complete in 2 seconds' often fail for transformer models, and what metric would you track separately to avoid this?
Show answer hint
A correct answer explains that total latency = TTFT (paid once per request) + (num_tokens - 1) × TPT, so a 100-token response naturally takes longer than a 5-token response even on identical hardware. The answer should identify that you need separate SLA thresholds for TTFT and TPT, not a single 'total time' threshold.