
Generate Synthetic Brain Volumes

Generative adversarial networks (GANs) can synthesize realistic brain MRI volumes. This is useful for data augmentation, privacy-preserving data sharing, and studying model biases. This tutorial trains a tiny Progressive GAN on downsampled brain data.

PRE_RELEASE = False
import subprocess, sys
try:
    import google.colab  # noqa: F401
    cmd = [sys.executable, "-m", "pip", "install", "-q",
           "nobrainer", "nilearn", "matplotlib", "pytorch-lightning"]
    if PRE_RELEASE:
        cmd.insert(4, "--pre")
    subprocess.check_call(cmd)
except ImportError:
    pass

1. Prepare downsampled training data

Progressive GANs start training at a low resolution and add layers to grow the output resolution step by step. For this tutorial we downsample the brain volumes to 4x4x4 voxels so that training completes in seconds.
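The code below shrinks each volume by interpolation with scipy.ndimage.zoom. When every input dimension is an exact multiple of the target size, block averaging is a simple alternative. The NumPy sketch here shows that swapped-in technique; the tutorial itself does not use it, and the helper name block_mean_downsample is hypothetical. Interpolation is the more general choice because common MRI shapes are not divisible by 4. The 182x218x182 MNI grid is one example.

```python
import numpy as np

def block_mean_downsample(vol, target_shape):
    """Downsample a 3D volume by averaging non-overlapping blocks.

    Hypothetical helper: assumes each input dimension is an exact
    multiple of the matching target dimension (e.g. 256 -> 4).
    """
    factors = []
    for s, t in zip(vol.shape, target_shape):
        if s % t != 0:
            raise ValueError(f"dimension {s} is not divisible by {t}")
        factors.append(s // t)
    fx, fy, fz = factors
    tx, ty, tz = target_shape
    # Split each axis into (target, block) and average over the block axes
    blocks = vol.reshape(tx, fx, ty, fy, tz, fz)
    return blocks.mean(axis=(1, 3, 5))

# Toy 8^3 volume whose values 0..511 are laid out in C order
vol = np.arange(8 ** 3, dtype=np.float32).reshape(8, 8, 8)
small = block_mean_downsample(vol, (4, 4, 4))
print(small.shape)  # (4, 4, 4)
```

Block averaging preserves the overall mean intensity exactly. Linear interpolation only samples values near the output grid points.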

import csv
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
import torch
from torch.utils.data import DataLoader, TensorDataset
from nobrainer.utils import get_data

csv_path = get_data()
with open(csv_path) as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row
    # Each row holds a (features path, labels path) pair
    filepaths = [(row[0], row[1]) for row in reader]

# Downsample volumes to 4^3
target_shape = (4, 4, 4)
volumes = []
# Only the first 5 volumes are used, to keep this demo fast
for feat_path, _ in filepaths[:5]:
    vol = nib.load(feat_path).get_fdata().astype(np.float32)
    # Compute zoom factors
    factors = tuple(t / s for t, s in zip(target_shape, vol.shape[:3]))
    small = zoom(vol, factors, order=1)
    # Normalize to [0, 1]
    small = (small - small.min()) / (small.max() - small.min() + 1e-8)
    volumes.append(small)

# Stack into a tensor: (N, 1, 4, 4, 4)
data_tensor = torch.tensor(np.stack(volumes)[:, None], dtype=torch.float32)
print("Training tensor shape:", data_tensor.shape)

# Build a DataLoader
train_dataset = TensorDataset(data_tensor)
loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
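This NumPy-only sketch checks the shape bookkeeping above. It uses random stand-in volumes rather than real MRI data. It repeats the per-volume min-max normalization, then stacks the results into an (N, 1, 4, 4, 4) array. Finally it counts the batches a loader with batch_size=2 would yield. PyTorch's DataLoader keeps the final partial batch by default (drop_last=False), so five volumes give three batches.

```python
import math
import numpy as np

rng = np.random.default_rng(0)

# Five stand-in "volumes" with an arbitrary intensity range
volumes = [rng.normal(loc=100.0, scale=30.0, size=(4, 4, 4)).astype(np.float32)
           for _ in range(5)]

# Same per-volume min-max normalization as in the cell above
normed = [(v - v.min()) / (v.max() - v.min() + 1e-8) for v in volumes]

# np.stack adds the batch axis; [:, None] inserts a channel axis of size 1
batch = np.stack(normed)[:, None]
print(batch.shape)  # (5, 1, 4, 4, 4)

# With batch_size=2 and drop_last=False, 5 samples give ceil(5 / 2) = 3 batches
batch_size = 2
n_batches = math.ceil(len(normed) / batch_size)
sizes = [min(batch_size, len(normed) - i * batch_size) for i in range(n_batches)]
print(n_batches, sizes)  # 3 [2, 2, 1]
```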

2. Train the Progressive GAN

We configure a very small model:

  • latent_size=32: small latent vector

  • fmap_base=32, fmap_max=32: minimal feature maps

  • resolution_schedule=[4]: single resolution level (4^3)

  • steps_per_phase=100: few steps per training phase
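Each generated volume starts from one latent vector of length latent_size. The sketch below shows that sampling step and assumes latents are drawn from a standard normal distribution, as in most GANs. It also rescales each vector to unit root-mean-square length, which mirrors the "pixel norm" applied to latents in the original Progressive GAN paper. Whether nobrainer applies this normalization is an assumption.

```python
import numpy as np

latent_size = 32   # matches model_args["latent_size"] above
batch_size = 2     # matches the DataLoader batch size above

rng = np.random.default_rng(42)
# One latent vector per generated volume, drawn from N(0, I)
z = rng.standard_normal((batch_size, latent_size)).astype(np.float32)

# Assumed ProGAN-style normalization: rescale each vector to unit RMS length
z_normed = z / np.sqrt(np.mean(z ** 2, axis=1, keepdims=True) + 1e-8)
print(z.shape, np.sqrt(np.mean(z_normed ** 2, axis=1)))
```

The generator then maps each 32-dimensional vector to a 4x4x4 volume. A larger latent_size gives the generator more degrees of freedom to vary between samples.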

from nobrainer.processing.generation import Generation

gen = Generation(
    "progressivegan",
    model_args={
        "latent_size": 32,
        "fmap_base": 32,
        "fmap_max": 32,
        "resolution_schedule": [4],
        "steps_per_phase": 100,
    },
)

gen.fit(loader, epochs=50)
print("GAN training complete!")
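Conceptually, each training step alternates between two updates. The discriminator learns to tell real volumes from generated ones, and the generator learns to fool the discriminator. The NumPy sketch below computes the standard non-saturating GAN losses from raw discriminator logits. It illustrates the general objective only. The original Progressive GAN paper mainly trained with a Wasserstein loss with gradient penalty (WGAN-GP), and the loss nobrainer uses internally is not confirmed here.

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)), computed stably via logaddexp
    return np.logaddexp(0.0, x)

def discriminator_loss(real_logits, fake_logits):
    # -log D(x) - log(1 - D(G(z))), with D = sigmoid(logit)
    return np.mean(softplus(-real_logits)) + np.mean(softplus(fake_logits))

def generator_loss(fake_logits):
    # Non-saturating generator loss: -log D(G(z))
    return np.mean(softplus(-fake_logits))

# An undecided discriminator (all logits 0) gives 2*ln(2) and ln(2)
print(discriminator_loss(np.zeros(4), np.zeros(4)), generator_loss(np.zeros(4)))  # ~1.386 ~0.693
```

A discriminator that confidently separates real (large positive logits) from fake (large negative logits) drives its own loss toward zero. At the same time, the generator's loss grows.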

3. Generate synthetic volumes

synthetic_images = gen.generate(2)

print(f"Generated {len(synthetic_images)} synthetic volumes")
for i, img in enumerate(synthetic_images):
    arr = np.asarray(img.dataobj)
    print(f"  Volume {i}: shape={arr.shape}, "
          f"range=[{arr.min():.3f}, {arr.max():.3f}], "
          f"mean={arr.mean():.3f}")
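The printed ranges and means are a first sanity check. The two hypothetical helpers below go slightly further but remain crude. One pools voxel intensities across a set of volumes. The other measures how far the voxelwise mean image of the generated set is from that of the real set. This is not a standard GAN quality metric such as FID, and it assumes both sets share the same array shape.

```python
import numpy as np

def intensity_summary(volumes):
    """Pooled (mean, std) of all voxel intensities in a list of arrays."""
    flat = np.concatenate([np.asarray(v, dtype=np.float64).ravel() for v in volumes])
    return float(flat.mean()), float(flat.std())

def mean_image_gap(real, fake):
    """Mean absolute difference between the voxelwise mean images of two sets."""
    real_mean = np.stack([np.asarray(v, dtype=np.float64) for v in real]).mean(axis=0)
    fake_mean = np.stack([np.asarray(v, dtype=np.float64) for v in fake]).mean(axis=0)
    if real_mean.shape != fake_mean.shape:
        raise ValueError("real and generated volumes must have the same shape")
    return float(np.abs(real_mean - fake_mean).mean())

# Stand-in data: two "real" and two "generated" 4x4x4 arrays
rng = np.random.default_rng(0)
real = [rng.random((4, 4, 4)) for _ in range(2)]
fake = [rng.random((4, 4, 4)) for _ in range(2)]
print(intensity_summary(real), intensity_summary(fake))
print(mean_image_gap(real, fake))
```

In this tutorial the inputs would be volumes and [np.asarray(img.dataobj) for img in synthetic_images]. This assumes the generated arrays are also 4x4x4.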

4. Visualize generated vs. real

At this tiny resolution the images will be blurry blobs, but the workflow is identical for full-resolution training.

try:
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 3, figsize=(10, 3))

    # Real volume (middle slice)
    real = volumes[0]
    axes[0].imshow(real[:, :, 2], cmap="gray", vmin=0, vmax=1)
    axes[0].set_title("Real (4x4x4)")
    axes[0].axis("off")

    # Generated volumes
    for i, img in enumerate(synthetic_images[:2]):
        arr = np.asarray(img.dataobj)
        axes[i + 1].imshow(arr[:, :, 2], cmap="gray")
        axes[i + 1].set_title(f"Generated {i + 1}")
        axes[i + 1].axis("off")

    plt.tight_layout()
    plt.show()
except ImportError:
    print("Install matplotlib for visualization")
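The cell above takes slice index 2 along the last axis, which is the central slice of a 4-voxel axis (4 // 2 = 2). The hypothetical helper below generalizes this to all three axes for volumes of any size. Naming the outputs sagittal, coronal and axial assumes the array axes follow RAS (x, y, z) order. That is common for NIfTI data but not guaranteed.

```python
import numpy as np

def middle_slices(vol):
    """Central 2D slice along each of the first three axes of a volume."""
    cx, cy, cz = (s // 2 for s in vol.shape[:3])
    return vol[cx, :, :], vol[:, cy, :], vol[:, :, cz]

vol = np.arange(4 ** 3).reshape(4, 4, 4)
sagittal, coronal, axial = middle_slices(vol)
print(axial.shape)  # (4, 4)
```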

5. Generated volumes grid

Generate 4 synthetic volumes and display them as a 2x2 grid of middle axial slices.

try:
    import matplotlib.pyplot as plt

    grid_images = gen.generate(4)

    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    for idx, ax in enumerate(axes.flat):
        arr = np.asarray(grid_images[idx].dataobj)
        mid = arr.shape[2] // 2
        ax.imshow(arr[:, :, mid].T, cmap="gray", origin="lower")
        ax.set_title(f"Generated {idx + 1}")
        ax.axis("off")

    plt.tight_layout()
    plt.show()
except ImportError:
    print("Install matplotlib for visualization")
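Instead of drawing one subplot per volume, you can tile the middle slices into a single array and draw them with one imshow call. All tiles then share the same intensity scale, which makes them easier to compare. The montage helper below is a hypothetical NumPy sketch, not part of nobrainer or matplotlib. It fills tiles row by row.

```python
import numpy as np

def montage(slices, nrows, ncols):
    """Tile equally shaped 2D slices into one (nrows*h, ncols*w) array, row-major."""
    h, w = slices[0].shape
    canvas = np.zeros((nrows * h, ncols * w), dtype=np.asarray(slices[0]).dtype)
    for idx, sl in enumerate(slices[: nrows * ncols]):
        r, c = divmod(idx, ncols)
        canvas[r * h:(r + 1) * h, c * w:(c + 1) * w] = sl
    return canvas

# Four constant 4x4 tiles with values 0, 1, 2, 3
tiles = [np.full((4, 4), k, dtype=np.float32) for k in range(4)]
grid = montage(tiles, 2, 2)
print(grid.shape)  # (8, 8)
```

Assuming the generated volumes are 4x4x4, montage([np.asarray(im.dataobj)[:, :, 2] for im in grid_images], 2, 2) would give a single 8x8 array.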

Notes for production use

For realistic brain generation:

  • Use higher-resolution data (e.g., 64^3 or 128^3 volumes)

  • Set resolution_schedule=[4, 8, 16, 32, 64] for progressive growing

  • Increase fmap_base and fmap_max (e.g., 512)

  • Train for thousands of steps per phase

  • Use GPU acceleration
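The cost of growing the resolution is easy to quantify. Every doubling of the side length multiplies the voxel count by 2^3 = 8. The arithmetic below extends the suggested schedule to 128 to cover the 128^3 case mentioned above. It counts voxels only. Actual memory use also depends on the feature-map counts and the batch size.

```python
schedule = [4, 8, 16, 32, 64, 128]   # suggested schedule, extended to 128
voxel_counts = [r ** 3 for r in schedule]

for res, n in zip(schedule, voxel_counts):
    print(f"{res:>3}^3 = {n:>9,} voxels ({n // voxel_counts[0]:>6,}x the 4^3 level)")
```

A 64^3 volume already has 4,096 times as many voxels as the 4^3 volumes used in this tutorial. That is why the tutorial trains at 4^3 and why GPU acceleration matters at full scale.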

Summary

We trained a Progressive GAN to generate synthetic 3D brain volumes. The Generation estimator follows a scikit-learn-style pattern: .fit() to train and .generate() to produce new samples. In the next tutorial we will look under the hood at custom training loops.
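To make the fit/generate pattern concrete without a GAN, here is a hypothetical toy estimator that models each voxel as an independent Gaussian. Only the calling convention mirrors the Generation estimator. The class name and its internals are invented for illustration, and its samples are noise rather than anatomy.

```python
import numpy as np

class VoxelGaussianGenerator:
    """Hypothetical toy generator: each voxel is an independent Gaussian.

    Illustrates the fit()/generate() calling pattern only; it is not a GAN.
    """

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)
        self.mean_ = None
        self.std_ = None

    def fit(self, volumes):
        # Estimate a per-voxel mean and standard deviation from the training set
        stack = np.stack([np.asarray(v, dtype=np.float64) for v in volumes])
        self.mean_ = stack.mean(axis=0)
        self.std_ = stack.std(axis=0)
        return self  # returning self allows chaining, as in scikit-learn

    def generate(self, n):
        if self.mean_ is None:
            raise RuntimeError("call fit() before generate()")
        # Draw one volume per sample from the fitted per-voxel Gaussians
        return [self.rng.normal(self.mean_, self.std_) for _ in range(n)]

train = [np.random.default_rng(i).random((4, 4, 4)) for i in range(5)]
samples = VoxelGaussianGenerator(seed=1).fit(train).generate(2)
print(len(samples), samples[0].shape)  # 2 (4, 4, 4)
```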