
Quantify Prediction Uncertainty

Standard segmentation models output a single “best guess” label per voxel. Bayesian models instead sample multiple predictions, letting us quantify uncertainty: how confident the model is at each voxel. This is critical for clinical applications, where knowing what the model does not know matters as much as knowing what it predicts.

This tutorial uses a Bayesian VNet built with Pyro (the pyro-ppl package), following the architecture from the kwyk brain labeling project.

PRE_RELEASE = False

# Install dependencies only when running on Google Colab;
# locally, we assume they are already installed.
import subprocess, sys
try:
    import google.colab  # noqa: F401
    cmd = [sys.executable, "-m", "pip", "install", "-q",
           "nobrainer", "nilearn", "matplotlib", "pyro-ppl"]
    if PRE_RELEASE:
        cmd.insert(4, "--pre")  # opt into pre-release builds
    subprocess.check_call(cmd)
except ImportError:
    pass  # not on Colab

1. Prepare data

import csv
from nobrainer.utils import get_data
from nobrainer.processing.dataset import Dataset

csv_path = get_data()
with open(csv_path) as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row
    filepaths = [(row[0], row[1]) for row in reader]  # (feature, label) pairs

BLOCK_SHAPE = (32, 32, 32)  # 32^3 matches kwyk training config

train_files = filepaths[:3]  # three subjects for a quick demo
eval_feature_path = filepaths[3][0]  # held-out feature volume for prediction

ds = (
    Dataset.from_files(train_files, block_shape=BLOCK_SHAPE, n_classes=2)
    .batch(2)
    .binarize()
)

print("Dataset ready:", len(ds.data), "subjects")

2. Train a Bayesian VNet

The bayesian_vnet uses Pyro’s stochastic weight layers (BayesianConv3d) in a U-Net-style encoder-decoder architecture. Each forward pass samples different weights, so repeated predictions give different outputs — this is what enables uncertainty quantification.

The encoder-decoder structure with skip connections works well even on small patches, making it a better choice than MeshNet for demos. The kwyk project (https://github.com/neuronets/kwyk) trained a Bayesian VNet on 11,000+ brain scans — here we use a tiny version.
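The core idea of a stochastic-weight layer, that every forward pass draws fresh weights from a prior distribution, can be sketched in plain NumPy (this is an illustrative stand-in, not nobrainer's actual `BayesianConv3d` implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

def stochastic_linear(x, out_features=2):
    """One forward pass with weights freshly drawn from a standard-normal prior."""
    in_features = x.shape[-1]
    W = rng.standard_normal((out_features, in_features))  # new weights every call
    b = rng.standard_normal(out_features)
    return x @ W.T + b

x = np.ones((1, 4))
y1 = stochastic_linear(x)
y2 = stochastic_linear(x)
print(np.allclose(y1, y2))  # False: every call samples new weights
```

Because the outputs differ between calls, repeating the forward pass yields a distribution of predictions rather than a point estimate, which is exactly what the uncertainty maps below are computed from.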

Key parameters (based on kwyk architecture):

  • base_filters=8: small for CPU demo (kwyk uses 16+)

  • levels=2: 2 encoder/decoder levels (kwyk uses 3-4)

  • prior_type="standard_normal": Gaussian prior on weights

  • Block shape 32^3 matches kwyk’s training configuration

For hyperparameter optimization, see Section 7 below on nobrainer research run.

from nobrainer.processing.segmentation import Segmentation  # noqa: E402

seg = Segmentation(
    "bayesian_vnet",
    model_args={
        "in_channels": 1,
        "n_classes": 2,
        "base_filters": 8,
        "levels": 2,
        "prior_type": "standard_normal",
    },
)

seg.fit(ds, epochs=3)
print("Bayesian VNet training complete!")

3. Predict with uncertainty

When n_samples > 0, the model runs multiple forward passes with different weight samples and returns three volumes:

  • label: the most frequent (mode) prediction across samples

  • variance: per-voxel variance across samples

  • entropy: Shannon entropy of the predictive distribution

Higher variance or entropy means the model is less certain.
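How these three volumes relate can be illustrated with synthetic data; the array below is a hypothetical stand-in for the per-voxel probabilities from three stochastic forward passes, not output from the actual model:

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-in for n_samples=3 stochastic forward passes over a 4x4x4 patch:
# per-voxel foreground probabilities from each pass.
probs = rng.uniform(size=(3, 4, 4, 4))
labels = (probs > 0.5).astype(int)  # per-sample hard labels

# label: majority vote across samples (the mode, for the binary case)
label = (labels.mean(axis=0) > 0.5).astype(int)

# variance: per-voxel spread of the sampled probabilities
variance = probs.var(axis=0)

# entropy: Shannon entropy (in bits) of the mean predictive distribution
p = probs.mean(axis=0)
entropy = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))

print(label.shape, variance.shape, entropy.shape)
```

Voxels where the three passes disagree get high variance and entropy near 1 bit; voxels where all passes agree get values near 0.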

result = seg.predict(
    eval_feature_path,
    block_shape=BLOCK_SHAPE,
    n_samples=3,
)

# Unpack the result tuple
label_img, variance_img, entropy_img = result

print("Label shape:", label_img.shape)
print("Variance shape:", variance_img.shape)
print("Entropy shape:", entropy_img.shape)

4. Examine uncertainty statistics

import numpy as np

var_data = np.asarray(variance_img.dataobj)
ent_data = np.asarray(entropy_img.dataobj)

print("Variance statistics:")
print(f"  Min:  {var_data.min():.6f}")
print(f"  Max:  {var_data.max():.6f}")
print(f"  Mean: {var_data.mean():.6f}")
print(f"  Std:  {var_data.std():.6f}")
print()
print("Entropy statistics:")
print(f"  Min:  {ent_data.min():.6f}")
print(f"  Max:  {ent_data.max():.6f}")
print(f"  Mean: {ent_data.mean():.6f}")
print(f"  Std:  {ent_data.std():.6f}")

5. Interpreting uncertainty

In a well-trained model, you would expect:

  • Low uncertainty in clearly brain or clearly background regions

  • High uncertainty at tissue boundaries (gray/white matter interface)

  • High uncertainty in ambiguous or pathological regions

With our tiny model and 3 epochs of training, the uncertainty values will not be meaningful, but the workflow is the same for production models.

# Percentage of voxels with above-average variance
high_var_pct = 100 * (var_data > var_data.mean()).sum() / var_data.size
print(f"Voxels with above-average variance: {high_var_pct:.1f}%")

# Percentage of voxels with above-average entropy
high_ent_pct = 100 * (ent_data > ent_data.mean()).sum() / ent_data.size
print(f"Voxels with above-average entropy: {high_ent_pct:.1f}%")

6. Visualize uncertainty maps

import nibabel as nib
import matplotlib.pyplot as plt

feature_vol = np.asarray(nib.load(eval_feature_path).dataobj)
label_data = np.asarray(label_img.dataobj)
mid_slice = feature_vol.shape[2] // 2

plt.figure(figsize=(16, 4))

plt.subplot(1, 4, 1)
plt.imshow(feature_vol[:, :, mid_slice].T, cmap="gray", origin="lower")
plt.title("Input volume")
plt.axis("off")

plt.subplot(1, 4, 2)
plt.imshow(label_data[:, :, mid_slice].T, cmap="gray", origin="lower")
plt.title("Predicted label")
plt.axis("off")

plt.subplot(1, 4, 3)
plt.imshow(var_data[:, :, mid_slice].T, cmap="hot", origin="lower")
plt.title("Variance map")
plt.colorbar()
plt.axis("off")

plt.subplot(1, 4, 4)
plt.imshow(ent_data[:, :, mid_slice].T, cmap="hot", origin="lower")
plt.title("Entropy map")
plt.colorbar()
plt.axis("off")

plt.tight_layout()
plt.show()

Summary

Bayesian models provide per-voxel uncertainty estimates via multiple stochastic forward passes. This is valuable for:

  • Flagging uncertain regions for expert review

  • Active learning (selecting the most informative samples to label)

  • Quality control in automated pipelines

In the next tutorial we will explore synthetic brain generation with GANs.

7. Finding optimal parameters with autoresearch

The model above uses hand-picked hyperparameters for a quick CPU demo. For production models, use nobrainer research run to automatically explore configurations overnight on a GPU:

nobrainer research run \
  --working-dir ./research/bayesian_vnet \
  --model-family bayesian_vnet \
  --max-experiments 15 \
  --budget-hours 8

The research loop will:

  1. Propose hyperparameter changes (via LLM or random grid)

  2. Train, evaluate, keep improvements, revert failures

  3. Save the best model with Croissant-ML metadata

Key hyperparameters to explore:

| Parameter | Range | Why it matters |
| --- | --- | --- |
| base_filters | 8, 16, 32 | Model capacity |
| levels | 2, 3, 4 | Depth of encoder-decoder |
| prior_type | "standard_normal", "laplace" | Weight prior shape |
| kl_weight | 1e-5 to 1e-2 (log scale) | Balance reconstruction vs. regularization |
| dropout_rate | 0.0, 0.1, 0.25 | Additional stochasticity |
| learning_rate | 1e-4 to 1e-2 (log scale) | Convergence speed |
| block_shape | (32,32,32), (64,64,64) | Context per patch |
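The random-grid side of the loop can be sketched as a simple sampler over this search space (a hypothetical illustration; the real nobrainer research run loop also handles training, evaluation, and bookkeeping):

```python
import random

random.seed(0)

# Discrete choices from the table above
SPACE = {
    "base_filters": [8, 16, 32],
    "levels": [2, 3, 4],
    "prior_type": ["standard_normal", "laplace"],
    "dropout_rate": [0.0, 0.1, 0.25],
    "block_shape": [(32, 32, 32), (64, 64, 64)],
}

def sample_config():
    cfg = {k: random.choice(v) for k, v in SPACE.items()}
    # Continuous ranges are drawn log-uniformly, as the table suggests
    cfg["kl_weight"] = 10 ** random.uniform(-5, -2)
    cfg["learning_rate"] = 10 ** random.uniform(-4, -2)
    return cfg

for trial in range(3):
    print(trial, sample_config())
```

Each sampled config would be handed to a training run; configs that improve the evaluation metric are kept, the rest reverted.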

See the kwyk project for a production reference: it was trained on 11,000+ subjects with block_shape=32, n_classes=50, and lr=0.0001.