Long-context fine-tuning
Why this matters
Production LLMs often need to handle documents (contracts, codebases, research papers) longer than their base context window. Fine-tuning with extended context teaches the model to attend meaningfully across longer sequences, which can substantially improve performance on such inputs. The fine-tuning itself adds no new architecture, but longer prompts still cost more compute and memory at inference.
Explanation
What it is: Long-context fine-tuning trains a language model on sequences longer than its original pre-training context window, typically after adjusting the positional encoding: RoPE (rotary position embedding) scaling for RoPE models, or ALiBi adjustments for ALiBi models. The goal is to make attention work reliably at positions the model never saw during pre-training.
How it works mechanically: Modern transformers (like Llama 2, Mistral) use rotary position embeddings (RoPE), which rotate query and key vectors by an angle proportional to each token's position, so attention scores depend on the relative distance between tokens. Before fine-tuning, you rescale those positions. Linear scaling (position interpolation) divides every position index by a factor, squeezing the longer sequence into the position range the model already knows; a related family of methods raises the RoPE frequency base instead. For example, a model trained on 4k tokens can be fine-tuned for 32k context with a linear scaling factor of 8. Fine-tuning then adapts the model to the rescaled positions, and the same scaling configuration must be used at inference.
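The linear variant of RoPE scaling (the `"type": "linear"` setting used in the code further down) can be sketched in a few lines of plain Python. This is an illustration rather than the library's implementation: `rope_angles` is a hypothetical helper, and `head_dim=128` and `base=10000.0` are assumed values matching Llama 2 7B.

```python
import math

def rope_angles(position, head_dim=128, base=10000.0, factor=1.0):
    """Rotation angles RoPE applies at one position (one angle per dimension pair).

    factor > 1 applies linear scaling (position interpolation): the position
    index is divided by factor before the angles are computed.
    """
    scaled_position = position / factor
    return [
        scaled_position * base ** (-2.0 * i / head_dim)
        for i in range(head_dim // 2)
    ]

original_window = 4096
new_window = 32768
factor = new_window / original_window  # 8.0

# The last position of the extended window maps to an effective position
# just inside the window the base model was trained on.
effective_last = (new_window - 1) / factor
print(effective_last)  # 4095.875

# With scaling, position 8000 gets the same angles position 1000 had before.
scaled = rope_angles(8000, factor=factor)
unscaled = rope_angles(1000)
print(all(math.isclose(a, b) for a, b in zip(scaled, unscaled)))  # True
```

Neighbouring tokens end up 1/8 of a position apart after scaling, a spacing the base model never saw, which is why fine-tuning is still needed.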
When to use it: Use this when your training corpus contains documents longer than the base context window AND you need those long-context capabilities at inference time. Common cases: RAG systems with multi-document retrieval, legal/financial document analysis, code repositories, research aggregation. Avoid if your documents are naturally short or if you can solve the problem with retrieval instead.
Analogy
Think of a translator trained only on sentences up to 50 words. At inference, you give them a 500-word paragraph. They'll likely lose coherence midway. Fine-tuning with long context is like having them practice with increasingly longer paragraphs during training: they learn how attention and continuity work across longer spans, so 500-word paragraphs start to feel natural.
Code
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_id = "meta-llama/Llama-2-7b-hf"

# Read the base context window from the config before loading the weights.
base_config = AutoConfig.from_pretrained(model_id)
original_max_pos = base_config.max_position_embeddings  # 4096 for Llama 2
new_max_context = 32768
scaling_factor = new_max_context / original_max_pos

# Pass rope_scaling at load time so the rotary embedding modules are built with it;
# editing model.config after loading may not reach modules that already exist.
# Linear scaling divides position indices by the factor and leaves inv_freq
# unchanged, so there is nothing to patch by hand.
# Key names depend on the transformers release; newer ones may expect
# "rope_type" instead of "type".
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # bf16 is less prone to overflow than fp16 in training
    device_map="auto",
    attn_implementation="flash_attention_2",  # requires the flash-attn package
    rope_scaling={"type": "linear", "factor": scaling_factor},
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Llama 2 ships without a pad token

# Synthetic placeholder documents; real training data should be coherent long texts.
document_1 = "Historical context: The Treaty of Westphalia (1648) established the modern state system. " * 200
document_2 = "Machine learning fundamentals: Neural networks learn by minimizing loss functions through gradient descent. " * 200
document_3 = "Software architecture patterns guide how large systems are structured. " * 200

training_examples = [
    {"text": document_1 + "[END]"},
    {"text": document_2 + "[END]"},
    {"text": document_3 + "[END]"},
]
dataset = Dataset.from_dict({"text": [ex["text"] for ex in training_examples]})

# LoRA trains small adapter matrices on the attention projections only.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
)

# Argument names follow the TRL release this example targets; newer releases
# rename max_seq_length to max_length and tokenizer to processing_class.
training_args = SFTConfig(
    output_dir="./long_context_output",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=1,
    learning_rate=2e-4,
    warmup_steps=10,
    max_seq_length=8192,  # truncation limit for each training sequence
    packing=False,        # keep one document per sequence
    dataset_text_field="text",
    optim="adamw_8bit",   # requires bitsandbytes
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    peft_config=lora_config,
    tokenizer=tokenizer,
)

print(f"Original context window: {original_max_pos}")
print(f"New context window: {new_max_context}")
print(f"Rope scaling factor: {scaling_factor:.2f}")
# Newer transformers releases may print additional keys here.
print(f"Model rope_scaling config: {model.config.rope_scaling}")
print(f"Training on {len(dataset)} documents")
print(f"Max sequence length for training: {training_args.max_seq_length}")
print("Ready to train. Call trainer.train() to begin.")

Output:
Original context window: 4096
New context window: 32768
Rope scaling factor: 8.00
Model rope_scaling config: {'type': 'linear', 'factor': 8.0}
Training on 3 documents
Max sequence length for training: 8192
Ready to train. Call trainer.train() to begin.
What just happened?
The code loaded a 7B Llama 2 model with a 4k base context, calculated an 8x RoPE scaling factor to support 32k context, configured linear RoPE scaling on the model, created three synthetic documents (a few thousand tokens each), wrapped them in a Hugging Face Dataset, configured LoRA for efficient fine-tuning, and set up an SFTTrainer with packing disabled to respect document boundaries. The model is now configured but not trained: calling trainer.train() would begin the actual fine-tuning process.
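To make "efficient" concrete, the number of parameters LoRA adds can be worked out by hand. The sketch below assumes the Llama-2-7B shapes (hidden size 4096, 32 decoder layers, square `q_proj` and `v_proj` matrices) and the `r=16` setting from the LoraConfig above; `lora_param_count` is a hypothetical helper, not a PEFT function.

```python
def lora_param_count(in_features, out_features, rank):
    """Parameters LoRA adds to one linear layer: A is (rank x in), B is (out x rank)."""
    return rank * in_features + out_features * rank

hidden_size = 4096      # assumed: Llama-2-7B hidden size
num_layers = 32         # assumed: Llama-2-7B decoder layers
rank = 16               # r from the LoraConfig above
targets_per_layer = 2   # q_proj and v_proj

per_module = lora_param_count(hidden_size, hidden_size, rank)
trainable = per_module * targets_per_layer * num_layers
print(per_module)  # 131072
print(trainable)   # 8388608
```

Under these assumptions the adapters add about 8.4 million trainable parameters, on the order of 0.1% of a 7B-parameter model. Activation memory, which grows with sequence length, is not reduced by LoRA.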
Common gotcha
Developers often set `max_seq_length` in SFTConfig to the full new context window (32k) but then train on documents only 8-16k tokens long. Any padding up to that limit consumes GPU memory and compute on tokens that carry no training signal. Instead, set `max_seq_length` to your actual training document length or slightly higher. RoPE scaling makes positions beyond that length addressable, but quality at lengths well past what the model saw during fine-tuning is not guaranteed, so validate at your target inference length. Also, `packing=True` concatenates documents and can split them across sequence boundaries, which weakens the long-range training signal; prefer `packing=False` for long-context fine-tuning.
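A quick way to see how much padding a given setup implies is to count pad tokens under different padding strategies. This is a rough sketch with made-up document lengths; `padding_tokens` is a hypothetical helper, and which strategy your trainer actually uses depends on its data collator and version.

```python
def padding_tokens(lengths, batch_size, pad_to=None):
    """Count pad tokens when sequences are batched in order.

    pad_to=None pads each batch to its longest sequence;
    pad_to=N pads every sequence to a fixed length N.
    """
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        target = pad_to if pad_to is not None else max(batch)
        total += sum(target - n for n in batch)
    return total

lengths = [8000, 12000, 16000, 9000]  # hypothetical tokenized document lengths

fixed = padding_tokens(lengths, batch_size=2, pad_to=32768)
dynamic = padding_tokens(lengths, batch_size=2)
print(fixed)    # 86072
print(dynamic)  # 11000
```

In this example, padding everything to 32k spends almost twice as many positions on padding as on real tokens (86,072 versus 45,000), while padding to the longest sequence in each batch wastes far less.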
Error recovery
RuntimeError: CUDA out of memory
ValueError: 'rope_scaling' is not a recognized key in config
Validation loss increases despite training loss decreasing
Experienced dev note
Long-context fine-tuning does NOT automatically teach a model to retrieve from or summarize long contexts: it only extends the range of positions over which attention can operate. If your model still performs poorly on long documents after fine-tuning, the issue is likely task-specific capability, not context window. You may need retrieval augmentation or a specialized long-context prompt strategy. Also, RoPE scaling is most effective when the fine-tuning documents actually use the extended context meaningfully: configuring a 32k window but training on documents that only fill the first 4k gives little benefit.
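Before training, it can help to check how much of the extended window the corpus actually exercises. The sketch below assumes you already have per-document token counts; `context_coverage` and the example lengths are hypothetical.

```python
def context_coverage(token_lengths, original_window, new_window):
    """Summarize how much of the extended window a corpus actually exercises."""
    beyond_original = [n for n in token_lengths if n > original_window]
    longest = max(token_lengths)
    return {
        "share_beyond_original": len(beyond_original) / len(token_lengths),
        "longest_document": longest,
        "share_of_new_window_reached": min(longest, new_window) / new_window,
    }

lengths = [3000, 5000, 6000, 16384]  # hypothetical token counts
report = context_coverage(lengths, original_window=4096, new_window=32768)
print(report)
# {'share_beyond_original': 0.75, 'longest_document': 16384, 'share_of_new_window_reached': 0.5}
```

A low `share_of_new_window_reached` suggests the corpus will not teach the model much about the far end of the new window.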
Check your understanding
You fine-tuned a model with 8x rope scaling on documents that average 6k tokens. At inference, you pass a 24k token document. Will the model perform well, and why or why not?
Answer hint
A correct answer acknowledges that the model learned attention patterns across its training document distribution (0-6k tokens) under the scaled RoPE positions. At inference with 24k tokens, the model is extrapolating beyond the lengths it learned from. This may work partially, because RoPE varies smoothly with position, but it is not guaranteed. Better answer: fine-tune on documents that reach close to your inference length (for example, fine-tune on 20k+ tokens to safely handle 24k at inference).