
How to use WandbCallback in Keras

Quick answer
Use WandbCallback from the wandb.keras module to automatically log training metrics, the model topology, and system metrics while a Keras model trains. Call wandb.init() before training, then pass WandbCallback() in the callbacks argument of model.fit().

PREREQUISITES

  • Python 3.8+
  • pip install wandb tensorflow
  • Wandb account and API key (free tier available)
  • Keras (part of TensorFlow 2.x)

Setup

Install the wandb and tensorflow packages, then log in with the wandb CLI to authenticate with your API key.

bash
pip install wandb tensorflow
wandb login

Step by step

Initialize a wandb run, build a simple Keras model, and add WandbCallback to model.fit() to track training metrics automatically.

python
import numpy as np
import tensorflow as tf
import wandb
from tensorflow.keras import layers
from wandb.keras import WandbCallback

# Initialize a new wandb run
wandb.init(project="keras-wandb-demo")

# Build a simple model: 32 input features, 10 output classes
model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)),
    layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Generate dummy data: 1000 samples with random integer labels in [0, 10)
x_train = np.random.rand(1000, 32)
y_train = np.random.randint(0, 10, 1000)

# Train with WandbCallback
model.fit(
    x_train, y_train,
    epochs=5,
    batch_size=32,
    callbacks=[WandbCallback()]
)

# Finish the wandb run
wandb.finish()
output
Epoch 1/5
32/32 [==============================] - 1s 10ms/step - loss: 2.3025 - accuracy: 0.0980
Epoch 2/5
32/32 [==============================] - 0s 7ms/step - loss: 2.2987 - accuracy: 0.1100
Epoch 3/5
32/32 [==============================] - 0s 7ms/step - loss: 2.2932 - accuracy: 0.1300
Epoch 4/5
32/32 [==============================] - 0s 7ms/step - loss: 2.2860 - accuracy: 0.1530
Epoch 5/5
32/32 [==============================] - 0s 7ms/step - loss: 2.2763 - accuracy: 0.1740
Wandb run finished.
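
The 32/32 at the start of each progress line is the number of batches per epoch: Keras splits the 1,000 samples into batches of 32 and keeps the final, smaller batch. Because the labels are random, a starting loss near 2.30 is also what a 10-class model guessing uniformly would score. A quick check of that arithmetic, using the same numbers as the example above:

```python
import math

num_samples = 1000  # rows in x_train
batch_size = 32     # batch_size passed to model.fit()
num_classes = 10    # units in the softmax output layer

# Keras keeps the final partial batch, so the step count rounds up.
steps_per_epoch = math.ceil(num_samples / batch_size)
last_batch_size = num_samples - (steps_per_epoch - 1) * batch_size

# Cross-entropy of a uniform guess over 10 classes is ln(10).
chance_loss = math.log(num_classes)

print(steps_per_epoch)        # 32
print(last_batch_size)        # 8
print(round(chance_loss, 4))  # 2.3026
```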

Common variations

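WandbCallback accepts optional arguments. monitor sets the metric used to decide which epoch is best (val_loss by default), save_model saves the model whenever that metric improves, and log_weights and log_gradients log histograms of the layer weights and gradients. The callback also works when model.fit() is given a tf.data dataset, and it logs the model topology.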

python
# val_loss is only available when model.fit() receives validation data
WandbCallback(monitor='val_loss', save_model=True, log_weights=True)
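
Which arguments make sense depends on how model.fit() is called. As a rough sketch, the function below picks a monitor metric that will actually appear in the training logs and leaves weight logging off unless asked for. The helper name `wandb_callback_kwargs` and its flags are made up for this illustration and are not part of wandb; only the `monitor`, `save_model`, and `log_weights` argument names come from the snippet above. The function only builds a dictionary, so it runs without wandb installed.

```python
def wandb_callback_kwargs(has_validation_data, log_weights=False):
    """Build keyword arguments for WandbCallback (hypothetical helper)."""
    # Keras only reports val_loss when fit() receives validation data,
    # so fall back to the training loss otherwise.
    monitor = "val_loss" if has_validation_data else "loss"
    return {
        "monitor": monitor,
        "save_model": True,          # keep the best model according to `monitor`
        "log_weights": log_weights,  # weight histograms add logging overhead
    }


kwargs = wandb_callback_kwargs(has_validation_data=False)
print(kwargs["monitor"])  # loss

# Usage, assuming wandb is installed and wandb.init() has been called:
#   from wandb.keras import WandbCallback
#   model.fit(x_train, y_train, callbacks=[WandbCallback(**kwargs)])
```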

Troubleshooting

  • If training metrics do not appear in your Wandb dashboard, ensure wandb.init() is called before training and your API key is correctly set.
  • If training is slow, leave log_gradients and log_weights disabled, or reduce the logging frequency (for example, log once per epoch rather than every few batches).
  • If you get authentication errors, run wandb login again or set the WANDB_API_KEY environment variable.
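
In scripts that run unattended (CI jobs, remote machines), an interactive wandb login is usually not possible, so the key typically comes from WANDB_API_KEY. Below is a minimal sketch of a pre-flight check. The helper name `api_key_is_set` is hypothetical. It only inspects the environment variable; it does not see a key stored by wandb login and does not contact the Wandb servers.

```python
import os


def api_key_is_set(environ=None):
    """Return True if WANDB_API_KEY is present and non-empty (hypothetical helper)."""
    env = os.environ if environ is None else environ
    return bool(env.get("WANDB_API_KEY", "").strip())


if not api_key_is_set():
    print("WANDB_API_KEY is not set; run `wandb login` or export the variable.")
```

Running this check before wandb.init() lets a job fail early with a clear message instead of stopping at a login prompt.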

Key Takeaways

  • Use WandbCallback in model.fit() to automatically log Keras training metrics to Wandb.
  • Always call wandb.init() before training to start a new run; it relies on the credentials set up with wandb login or WANDB_API_KEY.
  • Customize WandbCallback with parameters like monitor and save_model for advanced tracking.
  • Ensure your Wandb API key is set via wandb login or the WANDB_API_KEY environment variable to avoid authentication issues.
Verified 2026-04