How to use WandbCallback in Keras
Quick answer
Use WandbCallback from the wandb.keras module to automatically log training metrics, model topology, and system metrics during Keras training. Call wandb.init() before training, then pass WandbCallback() in the callbacks list of model.fit().
PREREQUISITES
- Python 3.8+
- pip install wandb tensorflow
- A Wandb account and API key (free tier available)
- Keras (bundled with TensorFlow 2.x)
Setup
Install the wandb and tensorflow packages, then log in via the wandb CLI to authenticate with your API key.
pip install wandb tensorflow
wandb login
Step by step
Initialize a wandb run, build a simple Keras model, and add WandbCallback to model.fit() to track training metrics automatically.
import wandb
from wandb.keras import WandbCallback
import tensorflow as tf
from tensorflow.keras import layers
# Initialize a new wandb run
wandb.init(project="keras-wandb-demo")
# Build a simple model
model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)),
    layers.Dense(10, activation='softmax')
])
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
# Generate dummy data
import numpy as np
x_train = np.random.rand(1000, 32)
y_train = np.random.randint(0, 10, 1000)
# Train with WandbCallback
model.fit(
    x_train, y_train,
    epochs=5,
    batch_size=32,
    callbacks=[WandbCallback()]
)
# Finish the wandb run
wandb.finish()
Output
Epoch 1/5
32/32 [==============================] - 1s 10ms/step - loss: 2.3025 - accuracy: 0.0980
Epoch 2/5
32/32 [==============================] - 0s 7ms/step - loss: 2.2987 - accuracy: 0.1100
Epoch 3/5
32/32 [==============================] - 0s 7ms/step - loss: 2.2932 - accuracy: 0.1300
Epoch 4/5
32/32 [==============================] - 0s 7ms/step - loss: 2.2860 - accuracy: 0.1530
Epoch 5/5
32/32 [==============================] - 0s 7ms/step - loss: 2.2763 - accuracy: 0.1740
Wandb run finished.
Common variations
You can customize WandbCallback with parameters such as monitor to choose which metric to track (monitoring val_loss requires passing validation data to model.fit()), save_model to save checkpoints, log_weights to log parameter histograms, or log_gradients to log gradient histograms. It also works with tf.data datasets.
WandbCallback(monitor='val_loss', save_model=True, log_weights=True)
Troubleshooting
- If training metrics do not appear in your Wandb dashboard, ensure wandb.init() is called before training and your API key is correctly set.
- For slow training, disable log_gradients or reduce logging frequency.
- If you get authentication errors, run wandb login again or set WANDB_API_KEY environment variable.
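As an alternative to the interactive wandb login, you can set the API key in the environment before calling wandb.init(). A minimal sketch; the key value below is a placeholder, not a real key:

```python
import os

# Set the API key programmatically; "your-api-key" is a placeholder --
# copy your real key from https://wandb.ai/authorize
os.environ["WANDB_API_KEY"] = "your-api-key"

# A subsequent wandb.init() will authenticate from the environment
# without prompting for `wandb login`.
```

This is handy in CI jobs or notebooks where an interactive login prompt is not possible; just be careful not to commit the key to version control.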
Key Takeaways
- Use WandbCallback in model.fit() to automatically log Keras training metrics to Wandb.
- Always call wandb.init() before training to start a new run.
- Customize WandbCallback with parameters like monitor and save_model for advanced tracking.
- Ensure your Wandb API key is set via wandb login or environment variable to avoid authentication issues.