How to · Beginner · 3 min read

How to use Weights and Biases with PyTorch

Quick answer
Use the wandb Python package to initialize a project and log training metrics in your PyTorch training loop. Initialize with wandb.init(), log metrics using wandb.log(), and optionally save model checkpoints for experiment tracking.

PREREQUISITES

  • Python 3.8+
  • pip install torch wandb
  • Weights and Biases account (free tier available)
  • Set the environment variable WANDB_API_KEY to your API key

Setup

Install the required packages and configure your environment to use Weights and Biases (wandb) with PyTorch.

bash
pip install torch wandb

Step by step

This example demonstrates a simple PyTorch training loop integrated with wandb for logging loss and accuracy metrics.

python
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb

# Initialize wandb project
wandb.init(project="pytorch-demo", entity="your-entity")  # entity (team or username) is optional

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

# Dummy dataset
inputs = torch.randn(100, 10)
labels = torch.randint(0, 2, (100,))

model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
for epoch in range(5):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # Calculate accuracy
    _, preds = torch.max(outputs, 1)
    accuracy = (preds == labels).float().mean().item()

    # Log metrics to wandb
    wandb.log({"epoch": epoch + 1, "loss": loss.item(), "accuracy": accuracy})

print("Training complete. Check your W&B dashboard for logs.")
output
Training complete. Check your W&B dashboard for logs.

Common variations

  • Use wandb.watch(model) to log gradients and model topology.
  • Log images, histograms, or custom charts with wandb.log().
  • Integrate with PyTorch Lightning via its built-in WandbLogger, or with other frameworks through their wandb integrations.
  • Call wandb.finish() at the end of training to flush pending data and close the run.
python
import wandb

# Watch the model (defined in the training example above) to log gradients
wandb.watch(model, log="all")

# Log a custom metric example
wandb.log({"custom_metric": 0.95})

# Finish run
wandb.finish()

Troubleshooting

  • If you see wandb: ERROR API key not found, set your API key with export WANDB_API_KEY=your_key or run wandb login.
  • For slow logging, reduce logging frequency or disable gradient logging with wandb.watch(model, log=None).
  • If runs do not appear on the dashboard, check your internet connection and firewall settings.

Key Takeaways

  • Initialize wandb with wandb.init() before training to start tracking.
  • Log metrics and artifacts inside the training loop using wandb.log() for real-time monitoring.
  • Use wandb.watch() to track model gradients and parameters automatically.
  • Set your WANDB_API_KEY environment variable to authenticate your runs.
  • Troubleshoot common issues by verifying API key setup and network connectivity.
Verified 2026-04