Part 2: Uncertainty Quantification
The Learnable Universe | Module 2 | COMP 536
“It is better to be vaguely right than exactly wrong.”
— Carveth Read
Learning Objectives
By the end of this reading, you will be able to:
- Distinguish between epistemic uncertainty (model uncertainty) and aleatoric uncertainty (data noise)
- Implement ensemble methods to quantify neural network uncertainty in JAX/Equinox
- Apply Monte Carlo Dropout as a Bayesian approximation for uncertainty estimation
- Evaluate uncertainty calibration using reliability diagrams and coverage metrics
- Design uncertainty-aware loss functions that incorporate prediction confidence
- Validate that your error bars are honest and scientifically meaningful
- Communicate uncertainty appropriately in scientific contexts
The fundamental difference between engineering and science:
- Engineering: “Does it work?” \(\to\) Accuracy matters most
- Science: “How confident are we?” \(\to\) Uncertainty matters equally
This reading teaches you how to get error bars from neural networks, so you can make fair, scientific comparisons between methods.
Part 1: Why Uncertainty Matters in Science
The Problem with Point Predictions
Standard neural network training gives you a single prediction: \[ \hat{y} = f_{\boldsymbol{\theta}}(\mathbf{x}) \]
What this tells you: “My best guess is \(\hat{y}\)”
What this doesn’t tell you:
- How confident am I?
- Should I trust this prediction?
- Where do I need more training data?
- Is this interpolation or extrapolation?
A Cautionary Tale: The Overconfident Model
Scenario: You train a neural network to predict core radius from initial conditions.
Training data: Clusters with \(N \in [100, 1000]\) stars
Test case: Cluster with \(N = 5000\) stars (extrapolation!)
NN prediction: \(r_{\text{core}} = 0.8\) pc (looks reasonable!)
True value: \(r_{\text{core}} = 2.5\) pc (factor of 3 wrong!)
The problem: The NN had no way to say “I don’t know, this is outside my training distribution.”
With uncertainty quantification:
- NN would predict: \(r_{\text{core}} = 0.8 \pm 2.0\) pc (huge error bars!)
- You’d know not to trust this prediction
- You’d know where to collect more training data
In Module 5 (Project 4), you learned Bayesian inference: \[ p(\theta \,|\, \mathcal{D}) = \frac{p(\mathcal{D} \,|\, \theta) p(\theta)}{p(\mathcal{D})} \]
The posterior \(p(\theta \,|\, \mathcal{D})\) gives you uncertainty about parameters.
For predictions, we want the posterior predictive distribution: \[ p(y_* \,|\, \mathbf{x}_*, \mathcal{D}) = \int p(y_* \,|\, \mathbf{x}_*, \theta) p(\theta \,|\, \mathcal{D}) \, d\theta \]
This is a distribution over predictions, not a single value! Neural networks approximate it using ensembles and dropout.
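In practice we never evaluate this integral exactly. Ensembles and MC dropout approximate it by Monte Carlo, treating each trained member (or dropout mask) as an approximate sample from the posterior: \[ p(y_* \,|\, \mathbf{x}_*, \mathcal{D}) \approx \frac{1}{M} \sum_{m=1}^{M} p(y_* \,|\, \mathbf{x}_*, \boldsymbol{\theta}_m), \qquad \boldsymbol{\theta}_m \sim p(\boldsymbol{\theta} \,|\, \mathcal{D}) \text{ (approximately)} \]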
Two Types of Uncertainty
1. Epistemic Uncertainty (Model Uncertainty)
- Uncertainty about the model itself
- “I don’t have enough data to know the right function”
- Reducible: Collect more training data \(\to\) decreases
Example:
- Few training points \(\to\) High epistemic uncertainty
- Many training points \(\to\) Low epistemic uncertainty
2. Aleatoric Uncertainty (Data Noise)
- Uncertainty inherent in the data
- “Even with a perfect model, the observations are noisy”
- Irreducible: More data doesn’t help (it’s inherent noise)
Example:
- N-body simulations have stochastic relaxation
- Even statistically identical initial conditions (same parameters, different random realizations) \(\to\) slightly different outcomes
- This is aleatoric uncertainty
| Question | Mostly epistemic | Mostly aleatoric |
|---|---|---|
| What causes it? | Missing knowledge about the learned function | Irreducible variability in the data-generating process |
| What reduces it? | Better coverage, better model class, or more simulations | Better measurement design can help, but more of the same data cannot remove it |
| Final-project symptom | Ensemble members disagree in sparse regions of \((Q_0, a)\) space | Repeated or nearby simulations produce genuinely scattered outcomes |
| How to report it | Show where the emulator is least stable | Include it in the likelihood or uncertainty budget if it affects the observable |
For your final project:
- Epistemic: Do we have enough N-body simulations to learn the pattern?
- Aleatoric: Even with infinite data, relaxation is stochastic
Why distinguish them?
- Epistemic \(\to\) Tells us where to collect more data
- Aleatoric \(\to\) Tells us fundamental limits of predictability
Part 2: Ensemble Methods — The Workhorse Approach
The Core Idea
Train multiple models with different random initializations: \[ f_{\boldsymbol{\theta}_1}(\mathbf{x}), \, f_{\boldsymbol{\theta}_2}(\mathbf{x}), \, \ldots, \, f_{\boldsymbol{\theta}_M}(\mathbf{x}) \]
Ensemble prediction (mean): \[ \bar{f}(\mathbf{x}) = \frac{1}{M} \sum_{m=1}^{M} f_{\boldsymbol{\theta}_m}(\mathbf{x}) \]
Ensemble uncertainty (standard deviation): \[ \sigma_{\text{ensemble}}(\mathbf{x}) = \sqrt{\frac{1}{M-1} \sum_{m=1}^{M} \left(f_{\boldsymbol{\theta}_m}(\mathbf{x}) - \bar{f}(\mathbf{x})\right)^2} \]
Interpretation:
- If all models agree \(\to\) Low uncertainty
- If models disagree \(\to\) High uncertainty
Neural networks are not unique!
Given the same data, different initializations converge to different local minima. These different solutions represent plausible explanations of the data.
The spread of predictions across the ensemble captures model uncertainty (epistemic).
Analogy: Imagine asking 5 different scientists to fit a model to data. If they all agree, you’re confident. If they disagree, you’re uncertain.
Implementation in JAX/Equinox
import jax
import jax.numpy as jnp
import equinox as eqx
from typing import List
class NeuralNetworkEnsemble(eqx.Module):
"""Ensemble of neural networks for uncertainty quantification
Args:
models: List of trained models
"""
models: List[eqx.Module]
def __init__(self, models):
self.models = models
def predict_with_uncertainty(self, x):
"""Predict with epistemic uncertainty
Args:
x: Input (d,) or batch (N, d)
Returns:
mean: Ensemble mean prediction
std: Ensemble standard deviation (epistemic uncertainty)
"""
# Check if batch or single input
is_batch = x.ndim > 1
if is_batch:
# Batch prediction: each model predicts for all samples
predictions = jnp.stack([
jax.vmap(model)(x) for model in self.models
]) # (n_models, batch_size, output_dim)
else:
# Single prediction
predictions = jnp.stack([
model(x) for model in self.models
]) # (n_models, output_dim)
# Ensemble statistics
mean = jnp.mean(predictions, axis=0)
std = jnp.std(predictions, axis=0)
return mean, std
def train_ensemble(
model_fn,
X_train, y_train,
X_val, y_val,
n_models=5,
n_epochs=1000,
batch_size=32,
learning_rate=1e-3,
base_seed=0
):
"""Train ensemble of models with different initializations
Args:
model_fn: Function that creates a new model given a random key
X_train, y_train: Training data
X_val, y_val: Validation data
n_models: Number of models in ensemble
n_epochs: Training epochs per model
batch_size: Batch size
learning_rate: Learning rate
base_seed: Base random seed
Returns:
ensemble: Trained ensemble
"""
import optax
trained_models = []
for i in range(n_models):
print(f"Training model {i+1}/{n_models}")
# Create model with unique initialization
key = jax.random.PRNGKey(base_seed + i)
model = model_fn(key)
# Train model
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
def loss_fn(model, x, y):
y_pred = jax.vmap(model)(x)
return jnp.mean((y - y_pred)**2)
@eqx.filter_jit
def train_step(model, opt_state, x, y):
loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
updates, opt_state = optimizer.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss
# Training loop
best_val_loss = jnp.inf
patience_counter = 0
patience = 50
for epoch in range(n_epochs):
# Shuffle and batch
key = jax.random.PRNGKey(epoch + i * 1000)
perm = jax.random.permutation(key, len(X_train))
epoch_loss = 0.0
n_batches = 0
for j in range(0, len(X_train), batch_size):
batch_idx = perm[j:j+batch_size]
x_batch = X_train[batch_idx]
y_batch = y_train[batch_idx]
model, opt_state, loss = train_step(model, opt_state, x_batch, y_batch)
epoch_loss += loss
n_batches += 1
# Validation
val_loss = loss_fn(model, X_val, y_val)
# Early stopping
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model = model
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= patience:
break
if epoch % 100 == 0:
print(f" Epoch {epoch}: train_loss={epoch_loss/n_batches:.4f}, val_loss={val_loss:.4f}")
trained_models.append(best_model)
print(f" Best val loss: {best_val_loss:.4f}\n")
return NeuralNetworkEnsemble(trained_models)
# Example usage
if __name__ == "__main__":
from module6_part3_nns import MLP # From Part 1
# Generate dummy data
    key = jax.random.PRNGKey(42)
    key_train, key_val, key_test = jax.random.split(key, 3)  # avoid reusing the same PRNG key
    X_train = jax.random.normal(key_train, (500, 5))
    y_train = jnp.sin(jnp.sum(X_train, axis=1, keepdims=True))
    X_val = jax.random.normal(key_val, (100, 5))
    y_val = jnp.sin(jnp.sum(X_val, axis=1, keepdims=True))
# Model factory
def create_mlp(key):
return MLP(in_size=5, out_size=1, width=64, depth=3, key=key)
# Train ensemble
ensemble = train_ensemble(
create_mlp,
X_train, y_train,
X_val, y_val,
n_models=5,
n_epochs=500
)
# Predict with uncertainty
    x_test = jax.random.normal(key_test, (5,))
    mean, std = ensemble.predict_with_uncertainty(x_test)
    print(f"Prediction: {mean[0]:.3f} ± {std[0]:.3f}")

Bootstrapping: Ensembles with Data Variation
Standard ensembles: Same data, different initializations (epistemic only)
Bootstrap ensembles: Different data subsets + different initializations (epistemic + some aleatoric)
Bootstrap procedure:
- For each model \(m\):
- Sample \(N\) training points with replacement
- Train model on this bootstrapped dataset
- Ensemble predictions capture both model and data uncertainty
def bootstrap_sample(X, y, key):
"""Sample with replacement"""
n = len(X)
indices = jax.random.choice(key, n, shape=(n,), replace=True)
return X[indices], y[indices]
def train_bootstrap_ensemble(
model_fn,
X_train, y_train,
X_val, y_val,
n_models=5,
**train_kwargs
):
"""Train ensemble with bootstrap sampling"""
trained_models = []
for i in range(n_models):
# Bootstrap sample
key = jax.random.PRNGKey(i)
X_boot, y_boot = bootstrap_sample(X_train, y_train, key)
        # Train a single model on the bootstrapped data
        # (reuse the training loop from train_ensemble, but with X_boot, y_boot)
        # ... training logic producing the trained `model` ...
        trained_models.append(model)
    return NeuralNetworkEnsemble(trained_models)

When to use bootstrap:
- When you want to capture both epistemic and aleatoric uncertainty
- When training data is limited (bootstrap effectively increases diversity)
- When you suspect data sampling affects results
Part 3: Monte Carlo Dropout — Fast Uncertainty Approximation
The Bayesian Interpretation of Dropout
Standard dropout (during training):
- Randomly drop neurons with probability \(p\)
- Forces network to learn robust features
Monte Carlo dropout (at test time):
- Keep dropout active during prediction
- Multiple forward passes \(\to\) distribution of predictions
Theoretical foundation (Gal & Ghahramani, 2016):
- Dropout approximates variational inference in Bayesian neural networks
- Each dropout mask samples from approximate posterior \(q(\boldsymbol{\theta})\)
Standard approach: Dropout is a regularization trick
Bayesian interpretation: Dropout is approximate Bayesian inference!
\[ p(\boldsymbol{\theta} \,|\, \mathcal{D}) \approx q(\boldsymbol{\theta}) \]
where \(q(\boldsymbol{\theta})\) is implicitly defined by the dropout mask distribution.
Practical benefit: Get Bayesian uncertainty without MCMC!
Implementation in Equinox
import jax
import jax.numpy as jnp
import equinox as eqx
class MLPWithDropout(eqx.Module):
"""MLP with dropout for MC uncertainty estimation"""
layers: list
dropouts: list
activation: callable
def __init__(
self,
in_size,
out_size,
width,
depth,
dropout_rate=0.1,
activation=jax.nn.relu,
*,
key
):
keys = jax.random.split(key, depth + 1)
self.layers = []
self.dropouts = []
# Input layer
self.layers.append(eqx.nn.Linear(in_size, width, key=keys[0]))
self.dropouts.append(eqx.nn.Dropout(p=dropout_rate))
# Hidden layers
for i in range(depth - 1):
self.layers.append(eqx.nn.Linear(width, width, key=keys[i+1]))
self.dropouts.append(eqx.nn.Dropout(p=dropout_rate))
# Output layer (no dropout)
self.layers.append(eqx.nn.Linear(width, out_size, key=keys[-1]))
self.activation = activation
def __call__(self, x, *, key=None):
"""Forward pass with optional dropout
Args:
x: Input
key: Random key for dropout (if None, no dropout)
"""
if key is not None:
keys = jax.random.split(key, len(self.dropouts))
else:
keys = [None] * len(self.dropouts)
# Hidden layers with activation and dropout
for layer, dropout, dropout_key in zip(self.layers[:-1], self.dropouts, keys):
x = layer(x)
x = self.activation(x)
if dropout_key is not None:
x = dropout(x, key=dropout_key)
# Output layer (no activation, no dropout)
x = self.layers[-1](x)
return x
def mc_dropout_predict(model, x, n_samples=100, *, key):
"""Monte Carlo dropout prediction with uncertainty
Args:
model: Model with dropout
x: Input (d,) or batch (N, d)
n_samples: Number of MC samples
key: Random key
Returns:
mean: Mean prediction
std: Standard deviation (epistemic uncertainty)
"""
keys = jax.random.split(key, n_samples)
# Check if batch
is_batch = x.ndim > 1
if is_batch:
# Multiple MC samples for each input in batch
predictions = jnp.stack([
jax.vmap(lambda xi: model(xi, key=k))(x)
for k in keys
]) # (n_samples, batch_size, output_dim)
else:
# Multiple MC samples for single input
predictions = jnp.stack([
model(x, key=k) for k in keys
]) # (n_samples, output_dim)
mean = jnp.mean(predictions, axis=0)
std = jnp.std(predictions, axis=0)
return mean, std
# Example usage
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    key_model, key_test, key_mc = jax.random.split(key, 3)  # separate keys for model init, test input, MC sampling
# Create model with dropout
model = MLPWithDropout(
in_size=5,
out_size=1,
width=64,
depth=3,
dropout_rate=0.2,
        key=key_model
)
# Train model (with dropout active)
# ... training code similar to ensemble ...
# Predict with MC dropout
    x_test = jax.random.normal(key_test, (5,))
    mean, std = mc_dropout_predict(model, x_test, n_samples=100, key=key_mc)
    print(f"MC Dropout Prediction: {mean[0]:.3f} ± {std[0]:.3f}")

Ensemble vs MC Dropout
| Aspect | Ensemble | MC Dropout |
|---|---|---|
| Training cost | High (train \(M\) models) | Low (train 1 model) |
| Inference cost | Low (1 forward pass/model) | Medium (\(T\) forward passes) |
| Uncertainty type | Epistemic | Epistemic (approximate) |
| Theoretical foundation | Bootstrap / Bayesian model averaging | Variational inference |
| Accuracy | Generally better | Good approximation |
| Memory | High (\(M\) models in memory) | Low (1 model) |
Recommendation:
- Ensemble: When you have compute budget and want best uncertainty estimates
- MC Dropout: When you need fast uncertainty at inference time
For your final project:
- Use ensembles (5–10 models) for primary results
- Optionally, compare with MC dropout as an ablation study
Part 4: Calibration — Are Your Error Bars Honest?
The Calibration Problem
Well-calibrated uncertainty:
- If you predict \(y = 10 \pm 2\) (95% interval: [6, 14])
- Then 95% of true values should fall in [6, 14]
Poorly calibrated:
- Overconfident: Error bars too small (true values often outside)
- Underconfident: Error bars too large (true values always inside)
Calibrated uncertainty is essential for science:
If your error bars are wrong:
- Can’t properly weight observations
- Can’t combine predictions with other measurements
- Can’t make reliable scientific conclusions
- Can’t know when to trust predictions
Example: If you claim 95% confidence intervals but only 60% of true values fall inside, you’re systematically overconfident. This breaks downstream inference!
Calibration Metrics
Expected Calibration Error (ECE):
- Sort predictions by confidence
- Bin into \(K\) bins (e.g., \(K=10\))
- For each bin: compare predicted confidence vs empirical accuracy
- Weighted average of the per-bin differences (a regression-style sketch follows this list)
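The bullets above describe ECE in its usual classification form. For regression, a common adaptation compares nominal interval coverage against empirical coverage over a grid of confidence levels. The sketch below assumes Gaussian predictive intervals; the function name and grid of levels are illustrative, not part of the course code.

import jax.numpy as jnp
from scipy import stats

def regression_ece(y_true, y_pred, y_std, n_levels=10):
    """Average gap between nominal and empirical coverage (0 = perfectly calibrated).

    Assumes Gaussian predictive intervals: the central p-interval is
    y_pred ± z_p * y_std with z_p = Phi^{-1}((1 + p) / 2).
    """
    nominal = jnp.linspace(0.05, 0.95, n_levels)       # nominal coverage levels
    z = jnp.abs(y_true - y_pred) / y_std               # standardized absolute residuals
    gaps = []
    for p in nominal:
        z_p = stats.norm.ppf(0.5 * (1.0 + float(p)))   # interval half-width in sigmas
        empirical = jnp.mean(z < z_p)                  # fraction of points inside the interval
        gaps.append(jnp.abs(empirical - p))
    return float(jnp.mean(jnp.stack(gaps)))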
Practical calibration for regression:
For each prediction \(\hat{y}_i \pm \sigma_i\), compute standardized residual: \[ z_i = \frac{y_i - \hat{y}_i}{\sigma_i} \]
If well-calibrated: \(z_i \sim \mathcal{N}(0, 1)\)
Tests:
- Mean: \(\bar{z} \approx 0\) (no systematic bias)
- Std: \(\text{std}(z) \approx 1\) (correct variance)
- Shapiro-Wilk: Test if \(z\) is normally distributed
- QQ-plot: Visual check of normality
Implementation
import jax.numpy as jnp
from scipy import stats
def compute_calibration_metrics(y_true, y_pred, y_std):
"""Compute calibration metrics for regression
Args:
y_true: True values (N,)
y_pred: Predicted means (N,)
y_std: Predicted standard deviations (N,)
Returns:
metrics: Dictionary of calibration metrics
"""
# Standardized residuals
z = (y_true - y_pred) / y_std
# Basic statistics
mean_z = float(jnp.mean(z))
std_z = float(jnp.std(z))
# Coverage at different confidence levels
coverage_68 = float(jnp.mean(jnp.abs(z) < 1.0)) # 68% (1 sigma)
coverage_95 = float(jnp.mean(jnp.abs(z) < 1.96)) # 95% (2 sigma)
    coverage_99 = float(jnp.mean(jnp.abs(z) < 2.58))  # 99% (2.58 sigma)
# Shapiro-Wilk test for normality
shapiro_stat, shapiro_p = stats.shapiro(z)
return {
"mean_z": mean_z, # Should be ~0
"std_z": std_z, # Should be ~1
"coverage_68": coverage_68, # Should be ~0.68
"coverage_95": coverage_95, # Should be ~0.95
"coverage_99": coverage_99, # Should be ~0.99
"shapiro_statistic": float(shapiro_stat),
"shapiro_p_value": float(shapiro_p),
"is_normal": shapiro_p > 0.05
}
def plot_calibration(y_true, y_pred, y_std):
"""Visualize calibration
Creates:
1. Reliability diagram (predicted vs empirical coverage)
2. QQ-plot (standardized residuals vs normal)
3. Histogram of standardized residuals
"""
import matplotlib.pyplot as plt
z = (y_true - y_pred) / y_std
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# 1. Reliability diagram
confidence_levels = jnp.linspace(0, 3, 30) # 0 to 3 sigma
empirical_coverage = []
for conf in confidence_levels:
coverage = jnp.mean(jnp.abs(z) < conf)
empirical_coverage.append(coverage)
# Theoretical coverage for normal distribution
theoretical_coverage = 2 * stats.norm.cdf(confidence_levels) - 1
axes[0].plot(confidence_levels, empirical_coverage, 'b-', label='Empirical')
axes[0].plot(confidence_levels, theoretical_coverage, 'r--', label='Ideal (Normal)')
axes[0].set_xlabel('Confidence Level (σ)')
axes[0].set_ylabel('Empirical Coverage')
axes[0].set_title('Reliability Diagram')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# 2. QQ-plot
stats.probplot(z, dist="norm", plot=axes[1])
axes[1].set_title('Q-Q Plot')
axes[1].grid(True, alpha=0.3)
# 3. Histogram
axes[2].hist(z, bins=30, density=True, alpha=0.7, edgecolor='black')
x_range = jnp.linspace(-4, 4, 100)
axes[2].plot(x_range, stats.norm.pdf(x_range), 'r-', linewidth=2, label='N(0,1)')
axes[2].set_xlabel('Standardized Residuals (z)')
axes[2].set_ylabel('Density')
axes[2].set_title('Distribution of Standardized Residuals')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
return fig
# Example usage
if __name__ == "__main__":
    # Generate dummy, roughly calibrated data
    key = jax.random.PRNGKey(42)
    key_true, key_noise, key_std = jax.random.split(key, 3)  # avoid reusing the same PRNG key
    n_test = 200
    y_true = jax.random.normal(key_true, (n_test,))
    noise_scale = 0.1
    y_pred = y_true + noise_scale * jax.random.normal(key_noise, (n_test,))  # small prediction noise
    y_std = noise_scale * (1.0 + 0.2 * jax.random.uniform(key_std, (n_test,)))  # varying uncertainty, matched to the noise scale
# Compute metrics
metrics = compute_calibration_metrics(y_true, y_pred, y_std)
print("Calibration Metrics:")
for key, value in metrics.items():
print(f" {key}: {value:.3f}")
# Visualize
    fig = plot_calibration(y_true, y_pred, y_std)
    import matplotlib.pyplot as plt  # plot_calibration only imports pyplot locally
    plt.show()

Calibration Strategies
If your model is overconfident (error bars too small):
- Temperature scaling: Scale uncertainties by a factor \(T > 1\), \[ \sigma_{\text{calibrated}} = T \cdot \sigma_{\text{predicted}} \] choosing \(T\) on a validation set to match the empirical coverage (see the sketch after this list)
- Train with a calibration loss: \[ \mathcal{L} = \text{MSE} + \lambda \left(\frac{1}{N}\sum_i z_i^2 - 1\right)^2 \] which penalizes deviation of the standardized residuals from unit variance
- Increase ensemble size: More models \(\to\) better uncertainty estimates
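A minimal sketch of the temperature-scaling step, assuming a single global factor applied to all predicted standard deviations (the function name is illustrative). Here \(T\) is chosen so the standardized residuals have unit variance on the validation set, a simple proxy for matching empirical coverage:

import jax.numpy as jnp

def fit_temperature(y_val, y_pred_val, y_std_val):
    """Fit a single scale factor T for the predicted uncertainties.

    With sigma_calibrated = T * sigma_predicted, the standardized residuals
    become z / T, so choosing T = std(z) on the validation set gives std(z / T) = 1.
    """
    z = (y_val - y_pred_val) / y_std_val
    return jnp.std(z)  # T > 1 if overconfident, T < 1 if underconfident

# Usage (illustrative): calibrate test-time uncertainties with the validation-fitted T
# T = fit_temperature(y_val, y_pred_val, y_std_val)
# y_std_test_calibrated = T * y_std_test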
If your model is underconfident (error bars too large):
- Reduce regularization: Less dropout, smaller weight decay
- Use smaller ensemble: Fewer models may give tighter estimates
- Check for dataset shift: Validation set may be harder than test set
Part 5: Uncertainty-Aware Training
Incorporating Uncertainty in Loss Functions
Standard loss (MSE): \[ \mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2 \]
Uncertainty-aware loss (negative log-likelihood): \[ \mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \left[\frac{(y_i - \hat{y}_i)^2}{2\sigma_i^2} + \log \sigma_i\right] \]
Interpretation:
- First term: Penalizes prediction error (weighted by uncertainty)
- Second term: Penalizes large uncertainties (prevents trivial solution \(\sigma_i \to \infty\))
Why this works: This is the negative log-likelihood under Gaussian noise!
Predicting Aleatoric Uncertainty
Train network to output both mean and variance: \[ (\mu, \sigma^2) = f_{\boldsymbol{\theta}}(\mathbf{x}) \]
Architecture:
class NetworkWithUncertainty(eqx.Module):
"""Network that predicts mean and aleatoric uncertainty"""
shared_layers: eqx.nn.MLP
mean_head: eqx.nn.Linear
log_var_head: eqx.nn.Linear # Predict log(σ²) for numerical stability
def __init__(self, in_size, hidden_size, out_size, *, key):
key1, key2, key3 = jax.random.split(key, 3)
# Shared feature extractor
self.shared_layers = eqx.nn.MLP(
in_size=in_size,
out_size=hidden_size,
width_size=hidden_size,
depth=3,
activation=jax.nn.relu,
key=key1
)
# Separate heads for mean and variance
self.mean_head = eqx.nn.Linear(hidden_size, out_size, key=key2)
self.log_var_head = eqx.nn.Linear(hidden_size, out_size, key=key3)
def __call__(self, x):
"""Predict mean and variance
Returns:
mean: Predicted mean (out_size,)
var: Predicted variance (out_size,)
"""
features = self.shared_layers(x)
mean = self.mean_head(features)
log_var = self.log_var_head(features)
var = jnp.exp(log_var) # Ensure positive
return mean, var
def gaussian_nll_loss(model, x, y):
"""Negative log-likelihood loss
Args:
model: NetworkWithUncertainty
x: Inputs (batch_size, in_size)
y: Targets (batch_size, out_size)
"""
def single_loss(xi, yi):
mean, var = model(xi)
# NLL = 0.5 * log(2π) + 0.5 * log(var) + 0.5 * (y - mean)^2 / var
# Omitting constant 0.5 * log(2π)
nll = 0.5 * jnp.log(var) + 0.5 * (yi - mean)**2 / var
return jnp.sum(nll) # Sum over output dimensions
    return jnp.mean(jax.vmap(single_loss)(x, y))

Now your network outputs:
- \(\mu\): Best prediction (epistemic captured by ensembles)
- \(\sigma^2\): Aleatoric uncertainty (data noise, irreducible)
Total uncertainty (for ensemble): \[ \sigma_{\text{total}}^2 = \underbrace{\sigma_{\text{epistemic}}^2}_{\text{ensemble variance}} + \underbrace{\bar{\sigma}_{\text{aleatoric}}^2}_{\text{mean predicted variance}} \]
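As a sketch of how this combination looks in code, assuming an ensemble whose members are NetworkWithUncertainty models as defined above (the helper name is illustrative):

import jax.numpy as jnp

def ensemble_total_uncertainty(models, x):
    """Combine epistemic (ensemble spread) and aleatoric (predicted variance) uncertainty.

    Args:
        models: list of trained NetworkWithUncertainty models
        x: single input (in_size,)

    Returns:
        mean: ensemble-mean prediction
        total_std: sqrt(epistemic variance + mean aleatoric variance)
    """
    outputs = [model(x) for model in models]
    means = jnp.stack([m for m, _ in outputs])       # (n_models, out_size)
    variances = jnp.stack([v for _, v in outputs])   # (n_models, out_size)

    mean = jnp.mean(means, axis=0)
    epistemic_var = jnp.var(means, axis=0)           # disagreement between ensemble members
    aleatoric_var = jnp.mean(variances, axis=0)      # average predicted data noise
    total_std = jnp.sqrt(epistemic_var + aleatoric_var)
    return mean, total_std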
Part 6: Practical Considerations for Your Final Project
Computational Budget vs Uncertainty Quality
Ensemble sizes:
- \(M = 5\): Minimum for reasonable uncertainty
- \(M = 10\): Good balance (recommended)
- \(M = 20\): Diminishing returns
- \(M > 50\): Rarely worth it
Rule of thumb: the error in the ensemble's uncertainty estimate shrinks roughly as \(1/\sqrt{M}\)
MC Dropout samples:
- \(T = 50\): Fast, rough uncertainty
- \(T = 100\): Balanced (recommended)
- \(T = 500\): High quality, slower
When to Use Which Method
Use Ensembles when:
- You have computational resources
- You need best uncertainty estimates
- Final production model (worth the cost)
- Your final project report (primary method)
Use MC Dropout when:
- Quick prototyping
- Limited memory (can’t store many models)
- Real-time applications
- Ablation study in your project
Use Predicted Variance when:
- Aleatoric uncertainty matters (e.g., noisy simulations)
- Combined with ensembles for total uncertainty
- You want to learn spatially-varying noise
Reporting Uncertainty in Your Final Project
For each method you use, report:
- Prediction accuracy: RMSE, MAE, \(R^2\) (a short metrics sketch follows this list)
- Calibration metrics: Mean/std of standardized residuals, coverage
- Uncertainty statistics: Mean uncertainty, min/max uncertainty
- Visualizations:
- Predictions with error bars vs true values
- Reliability diagrams
- Uncertainty vs distance to training data
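A minimal sketch of the accuracy metrics listed above (RMSE, MAE, \(R^2\)), assuming 1-D arrays of true and predicted values; the function name is illustrative:

import jax.numpy as jnp

def accuracy_metrics(y_true, y_pred):
    """Point-prediction accuracy metrics: RMSE, MAE, and R^2."""
    residuals = y_true - y_pred
    rmse = jnp.sqrt(jnp.mean(residuals**2))
    mae = jnp.mean(jnp.abs(residuals))
    ss_res = jnp.sum(residuals**2)                       # residual sum of squares
    ss_tot = jnp.sum((y_true - jnp.mean(y_true))**2)     # total sum of squares
    r2 = 1.0 - ss_res / ss_tot
    return {"rmse": float(rmse), "mae": float(mae), "r2": float(r2)}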
Example visualization:
def plot_predictions_with_uncertainty(y_true, y_pred, y_std, title="Predictions"):
"""Plot predictions with error bars"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Predicted vs True
axes[0].errorbar(y_true, y_pred, yerr=2*y_std, fmt='o', alpha=0.5, capsize=3)
axes[0].plot([y_true.min(), y_true.max()],
[y_true.min(), y_true.max()], 'r--', label='Perfect')
axes[0].set_xlabel('True Value')
axes[0].set_ylabel('Predicted Value')
axes[0].set_title(f'{title}: Predictions with 95% CI')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Residuals vs Uncertainty
residuals = jnp.abs(y_true - y_pred)
axes[1].scatter(y_std, residuals, alpha=0.5)
axes[1].plot([0, y_std.max()], [0, 2*y_std.max()], 'r--', label='2σ line')
axes[1].set_xlabel('Predicted Uncertainty (σ)')
axes[1].set_ylabel('Absolute Residual')
axes[1].set_title('Uncertainty vs Error')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
    return fig

Part 7: Advanced Topics
Conformal Prediction
Problem: Ensembles give uncertainty, but not guaranteed coverage
Solution: Conformal prediction provides distribution-free coverage guarantees
Key idea:
- Compute residuals on calibration set: \(|y_i - \hat{y}_i|\)
- Find the quantile \(q_\alpha\) such that a \(1-\alpha\) fraction of residuals falls below \(q_\alpha\)
- For new prediction: \(\hat{y} \pm q_\alpha\) has coverage \(\geq 1-\alpha\)
Advantage: Works for any model (even black box!)
def conformal_prediction_interval(y_cal, y_pred_cal, y_pred_test, alpha=0.05):
"""Compute conformal prediction intervals
Args:
y_cal: True values on calibration set (N_cal,)
y_pred_cal: Predictions on calibration set (N_cal,)
y_pred_test: Predictions on test set (N_test,)
alpha: Desired error rate (0.05 for 95% intervals)
Returns:
lower: Lower bounds (N_test,)
upper: Upper bounds (N_test,)
"""
# Compute absolute residuals on calibration set
residuals = jnp.abs(y_cal - y_pred_cal)
# Find (1-alpha) quantile
n = len(residuals)
    q_level = jnp.ceil((n + 1) * (1 - alpha)) / n
    q_level = jnp.minimum(q_level, 1.0)  # guard: the finite-sample correction can exceed 1 for small calibration sets
    q = jnp.quantile(residuals, q_level)
# Prediction intervals
lower = y_pred_test - q
upper = y_pred_test + q
    return lower, upper

When to use: When you need guaranteed coverage (e.g., safety-critical applications)
Bayesian Neural Networks
Full Bayesian approach: Put prior on all weights \[ p(\mathbf{W} \,|\, \mathcal{D}) \propto p(\mathcal{D} \,|\, \mathbf{W}) p(\mathbf{W}) \]
Challenge: Posterior is intractable (millions of parameters!)
Solutions:
- Variational inference: Approximate posterior with tractable family
- MCMC: Sample weights (expensive!)
- Laplace approximation: Gaussian approximation around MAP estimate
Practical note: These are advanced techniques beyond this course, but worth knowing they exist.
Conceptual Checkpoints
Epistemic vs Aleatoric: You predict \(r_{\text{core}} = 1.0 \pm 0.2\) pc. You collect 1000 more training simulations. Will the uncertainty decrease? Which type of uncertainty does this primarily capture?
Ensemble Size: You train ensembles of size \(M = 2, 5, 10, 20, 50\). Plot (mentally) how uncertainty estimates change. At what point do you see diminishing returns?
Calibration: Your NN ensemble predicts 95% confidence intervals, but when you check, only 70% of true values fall inside. Is your model overconfident or underconfident? How would you fix it?
Extrapolation: You train on \(N \in [100, 1000]\) and test on \(N = 5000\). Your ensemble gives \(r_{\text{core}} = 0.8 \pm 0.1\) pc (tight uncertainty). Should you trust this? Why might ensembles be overconfident in extrapolation?
MC Dropout: Explain why keeping dropout active at test time approximates Bayesian inference. How is this different from standard dropout as regularization?
Connection to Module 5: How is ensemble uncertainty related to the posterior distribution \(p(\theta \,|\, \mathcal{D})\) from Bayesian inference? What role do different initializations play?
Summary and Looking Forward
What You’ve Learned
You now understand:
- Why uncertainty matters for science (trustworthiness, data collection)
- Two types: Epistemic (model) vs aleatoric (data noise)
- Ensemble methods: Train multiple models, capture disagreement
- MC Dropout: Fast Bayesian approximation
- Calibration: Checking if error bars are honest
- Uncertainty-aware training: Networks that predict their own uncertainty
- Practical implementation: Complete JAX/Equinox code
For Your Final Project
Your Report Should Include:
- Uncertainty Analysis:
- Calibration plots for all methods
- Coverage statistics (68%, 95%, 99%)
- Discussion of when predictions are trustworthy
- Method Comparison:
- Not just “Method A is 10% more accurate”
- But “Method A is more accurate but Method B is better calibrated”
- Trade-offs: accuracy vs uncertainty quality
- Scientific Recommendations:
- When to trust your model’s predictions
- Where more training data would help (high epistemic uncertainty)
- Fundamental limits of predictability (aleatoric uncertainty)
The difference between prediction and science:
- Prediction: “The answer is 1.23”
- Science: “The answer is 1.23 \(\pm\) 0.15, calibrated on held-out data”
Why this matters:
In engineering, being wrong is costly. In science, being overconfident is fatal.
If you claim narrow error bars that don’t reflect true uncertainty:
- Other scientists can’t properly combine your results with theirs
- Downstream inference (parameter estimation, model selection) breaks
- You can’t identify where more data is needed
As computational scientists, we have a responsibility — not just to predict accurately, but to quantify our ignorance honestly.
That’s what makes machine learning scientific machine learning.