How to fine-tune BERT for text classification
Quick answer
Fine-tune BERT for text classification by loading a pretrained bert-base-uncased model with a classification head, preparing your labeled dataset, and training the model using a framework like Hugging Face Transformers and PyTorch. This involves tokenizing input texts, setting up a Trainer or custom training loop, and optimizing the model weights on your classification task.
Prerequisites
- Python 3.8+
- pip install torch transformers datasets
- Basic knowledge of PyTorch and Hugging Face Transformers
Setup
Install the required libraries: transformers for model and tokenizer, datasets for loading datasets, and torch for training.
```bash
pip install torch transformers datasets
```
Step by step
This example fine-tunes bert-base-uncased on a text classification dataset using Hugging Face's Trainer API.
```python
import numpy as np
import torch
from transformers import (
    BertForSequenceClassification,
    BertTokenizerFast,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

# Load dataset (e.g., SST-2 for sentiment classification)
dataset = load_dataset('glue', 'sst2')

# Load tokenizer and model
model_name = 'bert-base-uncased'
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Tokenize function
def tokenize_function(examples):
    return tokenizer(examples['sentence'], padding='max_length', truncation=True)

# Tokenize datasets
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Set format for PyTorch
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    save_strategy='epoch',  # must match evaluation_strategy when load_best_model_at_end=True
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
)

# Define accuracy metric (computed directly with NumPy;
# datasets.load_metric is deprecated)
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {'accuracy': (predictions == labels).mean()}

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics,
)

# Train model
trainer.train()

# Evaluate model
results = trainer.evaluate()
print(results)
```
Output
```
{'eval_loss': 0.25, 'eval_accuracy': 0.91, 'eval_runtime': 12.34, 'eval_samples_per_second': 50.0}
```
Common variations
- Use `Trainer` with different pretrained BERT variants like `bert-large-uncased`.
- Implement a custom training loop with PyTorch for more control.
- Use mixed precision training with `fp16=True` in `TrainingArguments` for faster training on GPUs.
- Fine-tune on your own dataset by replacing the `datasets.load_dataset` call with a custom dataset loader.
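The custom-training-loop variation can be sketched as follows. To keep the sketch self-contained and fast, it substitutes a tiny `nn.Linear` classifier and random tensors for the real BERT model and tokenized batches (both stand-ins are placeholders, not part of the example above); the loop structure — zero gradients, forward pass, loss, backward pass, optimizer step — is the same when you swap in `BertForSequenceClassification` and a `DataLoader` over your tokenized dataset.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
features = torch.randn(64, 8)        # stand-in for tokenized inputs / BERT features
labels = torch.randint(0, 2, (64,))  # binary labels, as in SST-2
loader = DataLoader(TensorDataset(features, labels), batch_size=16, shuffle=True)

model = nn.Linear(8, 2)              # stand-in for BertForSequenceClassification
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
loss_fn = nn.CrossEntropyLoss()

model.train()
for epoch in range(3):
    for batch_features, batch_labels in loader:
        optimizer.zero_grad()                       # clear old gradients
        logits = model(batch_features)              # forward pass
        loss = loss_fn(logits, batch_labels)        # classification loss
        loss.backward()                             # backpropagate
        optimizer.step()                            # update weights
```

The extra control matters when you need custom schedulers, loss weighting, or per-step logging that the Trainer API does not expose directly.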
Troubleshooting
- If you get CUDA out-of-memory errors, reduce `per_device_train_batch_size` or use gradient accumulation.
- If accuracy is low, check data preprocessing and ensure labels match the model's output classes.
- For slow training, enable mixed precision with `fp16=True` and use a GPU.
- If the tokenizer truncates important text, adjust `max_length` or use dynamic padding.
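The gradient-accumulation fix for out-of-memory errors can be enabled by setting `gradient_accumulation_steps` in `TrainingArguments`, or written by hand in a custom loop. Here is a minimal sketch of the manual version, again using a tiny linear model and random tensors as stand-ins for BERT and real batches:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(8, 2)                 # stand-in for the full BERT model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()
accumulation_steps = 4                  # 4 micro-batches of 4 ~ one batch of 16

optimizer.zero_grad()
for step in range(accumulation_steps):
    features = torch.randn(4, 8)        # one small micro-batch fits in memory
    labels = torch.randint(0, 2, (4,))
    # Scale the loss so the accumulated gradients average over the effective batch
    loss = loss_fn(model(features), labels) / accumulation_steps
    loss.backward()                     # gradients accumulate across micro-batches
optimizer.step()                        # one update for the effective batch of 16
optimizer.zero_grad()
```

With the Trainer API, the equivalent is `per_device_train_batch_size=4, gradient_accumulation_steps=4` in `TrainingArguments`.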
Key Takeaways
- Use Hugging Face Transformers and PyTorch to fine-tune BERT efficiently for text classification.
- Tokenize your dataset properly with padding and truncation to fit BERT's input requirements.
- Leverage the Trainer API for simplified training and evaluation with built-in metrics support.
- Adjust batch size and learning rate to balance training speed and model performance.
- Troubleshoot common issues like memory errors by tuning batch size or enabling mixed precision.