Standard segmentation models output a single “best guess” label per voxel. Bayesian models instead sample multiple predictions, letting us quantify uncertainty -- how confident the model is at each voxel. This is critical for clinical applications where knowing what the model does not know matters as much as knowing what it predicts.
This tutorial trains a Bayesian VNet with Pyro (the pyro-ppl package), following the architecture from the kwyk brain-labeling project.
PRE_RELEASE = False
import subprocess, sys
try:
    import google.colab  # noqa: F401
    cmd = [sys.executable, "-m", "pip", "install", "-q",
           "nobrainer", "nilearn", "matplotlib", "pyro-ppl"]
    if PRE_RELEASE:
        cmd.insert(4, "--pre")
    subprocess.check_call(cmd)
except ImportError:
    pass
1. Prepare data¶
import csv
from nobrainer.utils import get_data
from nobrainer.processing.dataset import Dataset
csv_path = get_data()
with open(csv_path) as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row
    filepaths = [(row[0], row[1]) for row in reader]
BLOCK_SHAPE = (32, 32, 32) # 32^3 matches kwyk training config
train_files = filepaths[:3]
eval_feature_path = filepaths[3][0]
ds = (
    Dataset.from_files(train_files, block_shape=BLOCK_SHAPE, n_classes=2)
    .batch(2)
    .binarize()
)
print("Dataset ready:", len(ds.data), "subjects")
2. Train a Bayesian VNet¶
The bayesian_vnet uses Pyro’s stochastic weight layers (BayesianConv3d)
in a U-Net-style encoder-decoder architecture. Each forward pass samples
different weights, so repeated predictions give different outputs -- this
is what enables uncertainty quantification.
The encoder-decoder structure with skip connections works well even on
small patches, making it a better choice than MeshNet for demos.
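The effect of stochastic weights can be illustrated with a tiny NumPy sketch. This is illustrative only, not nobrainer or Pyro code: weights are drawn fresh from a Gaussian on every forward pass, so two calls on the same input disagree.

```python
import numpy as np

rng = np.random.default_rng(0)

def stochastic_forward(x, w_mu, w_sigma):
    """Sample weights from N(w_mu, w_sigma) and apply a linear map."""
    w = rng.normal(w_mu, w_sigma)  # new weight sample on every call
    return x @ w

x = np.ones((1, 4))
w_mu = np.zeros((4, 2))          # variational mean of the weights
w_sigma = np.full((4, 2), 0.1)   # variational std of the weights

out1 = stochastic_forward(x, w_mu, w_sigma)
out2 = stochastic_forward(x, w_mu, w_sigma)
print(np.allclose(out1, out2))  # False: each pass sampled different weights
```

In the real model, Pyro's variational layers play the role of `stochastic_forward`, and the spread across repeated passes is what the uncertainty maps summarize.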
Key parameters (based on the kwyk architecture):
base_filters=8: small for CPU demo (kwyk uses 16+)
levels=2: 2 encoder/decoder levels (kwyk uses 3-4)
prior_type="standard_normal": Gaussian prior on weights
Block shape 32^3 matches kwyk's training configuration
For hyperparameter optimization, use nobrainer research run to
automatically explore configurations overnight on a GPU.
from nobrainer.processing.segmentation import Segmentation # noqa: E402
seg = Segmentation(
    "bayesian_vnet",
    model_args={
        "in_channels": 1,
        "n_classes": 2,
        "base_filters": 8,
        "levels": 2,
        "prior_type": "standard_normal",
    },
)
seg.fit(ds, epochs=3)
print("Bayesian VNet training complete!")
3. Predict with uncertainty¶
When n_samples > 0, the model runs multiple forward passes with
different weight samples and returns three volumes:
label: the most frequent (mode) prediction across samples
variance: per-voxel variance across samples
entropy: Shannon entropy of the predictive distribution
Higher variance or entropy means the model is less certain.
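To make the three outputs concrete, here is a small NumPy sketch of how they can be derived from a stack of sampled predictions. The shapes and names are illustrative, not nobrainer internals; it assumes binary labels and three samples over a tiny 2x2 "volume".

```python
import numpy as np

# Three hypothetical sampled binary predictions, shape (n_samples, 2, 2)
samples = np.array([
    [[0, 1], [1, 1]],
    [[0, 1], [0, 1]],
    [[0, 0], [1, 1]],
])

# label: per-voxel mode across samples; for binary labels the mode
# is simply whether the empirical foreground probability exceeds 0.5
p_fg = samples.mean(axis=0)
label = (p_fg > 0.5).astype(int)

# variance: per-voxel variance across samples
variance = samples.var(axis=0)

# entropy: Shannon entropy of the empirical predictive distribution
eps = 1e-12  # avoid log(0)
entropy = -(p_fg * np.log(p_fg + eps) + (1 - p_fg) * np.log(1 - p_fg + eps))

print(label)              # voxels where samples agree get the shared label
print(variance.round(3))  # zero where all samples agree
print(entropy.round(3))   # highest where samples disagree most
```

Voxels where all three samples agree get zero variance and (near-)zero entropy; voxels where the samples split get high values of both.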
result = seg.predict(
    eval_feature_path,
    block_shape=BLOCK_SHAPE,
    n_samples=3,
)
# Unpack the result tuple
label_img, variance_img, entropy_img = result
print("Label shape:", label_img.shape)
print("Variance shape:", variance_img.shape)
print("Entropy shape:", entropy_img.shape)
4. Examine uncertainty statistics¶
import numpy as np
var_data = np.asarray(variance_img.dataobj)
ent_data = np.asarray(entropy_img.dataobj)
print("Variance statistics:")
print(f" Min: {var_data.min():.6f}")
print(f" Max: {var_data.max():.6f}")
print(f" Mean: {var_data.mean():.6f}")
print(f" Std: {var_data.std():.6f}")
print()
print("Entropy statistics:")
print(f" Min: {ent_data.min():.6f}")
print(f" Max: {ent_data.max():.6f}")
print(f" Mean: {ent_data.mean():.6f}")
print(f" Std: {ent_data.std():.6f}")
5. Interpreting uncertainty¶
In a well-trained model, you would expect:
Low uncertainty in clearly brain or clearly background regions
High uncertainty at tissue boundaries (gray/white matter interface)
High uncertainty in ambiguous or pathological regions
With our tiny model and 2 epochs of training, the uncertainty values will not be meaningful -- but the workflow is the same for production models.
# Percentage of voxels with above-average variance
high_var_pct = 100 * (var_data > var_data.mean()).sum() / var_data.size
print(f"Voxels with above-average variance: {high_var_pct:.1f}%")
# Percentage of voxels with above-average entropy
high_ent_pct = 100 * (ent_data > ent_data.mean()).sum() / ent_data.size
print(f"Voxels with above-average entropy: {high_ent_pct:.1f}%")
6. Visualize uncertainty maps¶
import nibabel as nib
import matplotlib.pyplot as plt
feature_vol = np.asarray(nib.load(eval_feature_path).dataobj)
label_data = np.asarray(label_img.dataobj)
mid_slice = feature_vol.shape[2] // 2
plt.figure(figsize=(16, 4))
plt.subplot(1, 4, 1)
plt.imshow(feature_vol[:, :, mid_slice].T, cmap="gray", origin="lower")
plt.title("Input volume")
plt.axis("off")
plt.subplot(1, 4, 2)
plt.imshow(label_data[:, :, mid_slice].T, cmap="gray", origin="lower")
plt.title("Predicted label")
plt.axis("off")
plt.subplot(1, 4, 3)
plt.imshow(var_data[:, :, mid_slice].T, cmap="hot", origin="lower")
plt.title("Variance map")
plt.colorbar()
plt.axis("off")
plt.subplot(1, 4, 4)
plt.imshow(ent_data[:, :, mid_slice].T, cmap="hot", origin="lower")
plt.title("Entropy map")
plt.colorbar()
plt.axis("off")
plt.tight_layout()
plt.show()
Summary¶
Bayesian models provide per-voxel uncertainty estimates via multiple stochastic forward passes. This is valuable for:
Flagging uncertain regions for expert review
Active learning (selecting the most informative samples to label)
Quality control in automated pipelines
In the next tutorial we will explore synthetic brain generation with GANs.
7. Finding optimal parameters with autoresearch¶
The model above uses hand-picked hyperparameters for a quick CPU demo.
For production models, use nobrainer research run to automatically
explore configurations overnight on a GPU:
nobrainer research run \
    --working-dir ./research/bayesian_vnet \
    --model-family bayesian_vnet \
    --max-experiments 15 \
    --budget-hours 8
The research loop will:
Propose hyperparameter changes (via LLM or random grid)
Train, evaluate, keep improvements, revert failures
Save the best model with Croissant-ML metadata
Key hyperparameters to explore:
| Parameter | Range | Why it matters |
|---|---|---|
| base_filters | 8, 16, 32 | Model capacity |
| levels | 2, 3, 4 | Depth of encoder-decoder |
| prior_type | "standard_normal", "laplace" | Weight prior shape |
| kl_weight | 1e-5 to 1e-2 (log scale) | Balance of reconstruction vs. regularization |
| dropout_rate | 0.0, 0.1, 0.25 | Additional stochasticity |
| learning_rate | 1e-4 to 1e-2 (log scale) | Convergence speed |
| block_shape | (32,32,32), (64,64,64) | Context per patch |
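For intuition on kl_weight: the variational objective conceptually combines a data-fit term with a KL penalty pulling the weight posterior toward the prior. This is a sketch of that balance, not nobrainer's actual loss function.

```python
def bayesian_loss(nll, kl_divergence, kl_weight=1e-3):
    """Negative log-likelihood plus weighted KL(q(w) || p(w)).

    Larger kl_weight regularizes harder toward the prior;
    smaller kl_weight lets the network fit the data more tightly.
    """
    return nll + kl_weight * kl_divergence

# Hypothetical per-batch values, for illustration only
print(bayesian_loss(2.0, 10.0, kl_weight=1e-2))  # 2.1
```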
See the kwyk project for reference: trained on 11,000+ subjects
with block_shape=32, n_classes=50, lr=0.0001.