Part 1: Neural Networks and Scientific Emulation

The Learnable Universe | Module 2 | COMP 536

Author

Anna Rosen

“What I cannot create, I do not understand.”

— Richard Feynman (found on his blackboard at the time of his death, 1988)

Learning Objectives

By the end of this module, you will be able to:

Important: Required Path and Optional Depth

Required for the final project: Understand the emulator hierarchy, train/validation/test discipline, the 2-input MLP emulator, normalization with training-set statistics, baseline comparisons, held-out residual diagnostics, ensemble uncertainty, calibration basics, and how the emulator becomes the forward model inside a likelihood.

Optional depth: The theorem details, extended ecosystem examples, MC dropout, predicted variance heads, conformal prediction, Bayesian neural networks, and advanced inference discussion help you see the bigger picture. Use them after the baseline emulator is working and validated.


Note: Your Roadmap Through This Module

Core Question: Can a validated neural-network emulator replace repeated N-body simulations inside Bayesian inference for cluster initial conditions without destroying scientific trust?

The Big Picture: You’ve built N-body simulations in Project 2, and you’re now rebuilding that simulator in JAX for the final project. You’ve also built Bayesian inference engines in Project 4. Now you’ll learn to build fast approximations to expensive computations. This is the key to the final project: training a neural network to predict N-body simulation summaries, assigning uncertainty to those predictions, and using the emulator inside a likelihood for inference.

Section 1.1: The Emulation Problem Why we need fast surrogates for expensive simulations, and what exactly the emulator is approximating.

Section 1.2: From Linear Regression to Neural Networks The intellectual progression: linear models \(\to\) feature engineering \(\to\) learned features.

Section 1.3: The Computational Neuron The mathematical building blocks. Neurons, activations, and why nonlinearity matters.

Section 1.4: Layering Neurons into Networks Architecture design: hidden layers, parameter counting, and the expressiveness tradeoff.

Section 1.5: The Universal Approximation Theorem Why neural networks can (in principle) approximate any function. The profound idea and its limits.

Section 1.6: Forward Propagation How input becomes output. The computational graph perspective you know from the JAX module.

Section 1.7: Training as Optimization Loss functions, gradient descent, and the connection to likelihood maximization.

Section 1.8: Weight Initialization — Breaking Symmetry Why random initialization matters and how it enables ensembles.

Section 1.9: Practical Training Normalization, learning rates, convergence, multi-output handling, and debugging.

Section 1.10: Uncertainty via Ensembles Why single networks are dangerous, and how ensembles estimate one part of epistemic uncertainty.

Section 1.11: The JAX Ecosystem — Equinox and Optax Professional tools for building and training neural networks.

Section 1.12: From Emulator to Inference Constructing the physical-unit likelihood and connecting your trained network to NumPyro for Bayesian parameter recovery.

Section 1.13: Synthesis What you’ve learned and how it all connects.

Tip: Vocabulary: AI, ML, Deep Learning, and Emulation

These words often get blurred together, so let’s separate them clearly:

  • Artificial Intelligence (AI): The broad umbrella for systems that perform tasks we associate with intelligence.
  • Machine Learning (ML): A subset of AI in which a model learns patterns from data instead of being programmed rule by rule.
  • Deep Learning: A subset of ML that uses multi-layer neural networks.
  • Scientific emulation: A specific ML use case in which we train a fast surrogate model to approximate an expensive scientific computation.

This reading is not trying to make you a general AI engineer. It is teaching one focused and useful workflow: how to build a trustworthy neural-network emulator for a scientific simulator, evaluate it honestly, and use it for inference.


1.1: The Emulation Problem

Priority: Essential

Why Emulators Matter

Emulator (or surrogate model): A fast approximation to an expensive computation. Trained on input-output pairs from the expensive model, then used in place of it.

Consider your N-body simulator. Each run takes seconds to minutes. But scientific inference requires thousands of forward-model evaluations. In direct inference, each likelihood evaluation may call the physical simulator. In emulator-based inference, each likelihood evaluation calls a trained surrogate. The question is whether the surrogate is accurate enough, calibrated enough, and limited to a trustworthy domain.

The computational bottleneck: Your final project asks you to infer what initial conditions \((Q_0, a)\) produced a given cluster outcome. If one simulation takes about 30 seconds, then 2000 likelihood evaluations would cost

\[ \underbrace{2000}_{\text{likelihood evaluations}} \times \underbrace{30\,\mathrm{s}}_{\text{per simulation}} \approx 17\,\mathrm{hours}. \]

If each likelihood evaluation averages over 50 stochastic cluster realizations, the cost becomes

\[ \underbrace{2000}_{\text{likelihood evaluations}} \times \underbrace{50}_{\text{realizations per likelihood}} \times \underbrace{30\,\mathrm{s}}_{\text{per simulation}} \approx 35\,\mathrm{days}. \]

The exact numbers depend on your simulator, hardware, timestep, and inference setup. The lesson is the scaling: direct simulation inside inference can become infeasible quickly.
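The scaling is easy to check for your own setup. Here is a minimal sketch using the numbers from the text; the function name `inference_cost_seconds` is illustrative, and the 30 s runtime is the assumed figure from above rather than a measurement.

```python
def inference_cost_seconds(n_likelihood, sim_seconds, n_realizations=1):
    """Wall-clock cost when every likelihood evaluation reruns the simulator."""
    return n_likelihood * n_realizations * sim_seconds


# One simulation per likelihood evaluation.
single = inference_cost_seconds(2000, 30.0)

# Fifty stochastic realizations averaged inside each likelihood evaluation.
averaged = inference_cost_seconds(2000, 30.0, n_realizations=50)

print(f"one run per likelihood: {single / 3600:.1f} hours")   # 16.7 hours
print(f"50 runs per likelihood: {averaged / 86400:.1f} days")  # 34.7 days
```

Swap in your own measured runtime and evaluation count to see where direct inference stops being practical.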

The emulator solution: Train a neural network on a modest, validated simulation ensemble. The network learns a fast approximation to the simulator’s summary-statistic map. Now inference can become an interactive analysis step, provided the held-out validation evidence supports trusting the emulator.

Tip: Predict – Try – Explain: Why Emulate?

Predict: Which part of the final project becomes impossible if each likelihood evaluation reruns the N-body simulator?

Try: Write the two workflows side by side: direct inference calls the simulator; emulator-based inference calls the trained surrogate.

Explain: The emulator is not replacing physics. It is replacing repeated expensive evaluations only after the simulator and held-out emulator tests are credible.

What Exactly Is Being Emulated?

The final project is not “train a neural network because neural networks are powerful.” The scientific problem is an inverse problem. We want to infer which initial cluster conditions,

\[ \mathbf{x} = (Q_0, a), \]

could have produced a measured or simulated cluster outcome,

\[ \mathbf{y} = \begin{bmatrix} f_{\rm bound} \\ \sigma_v \\ r_h \end{bmatrix}. \]

Here \(f_{\rm bound}\) is dimensionless, \(\sigma_v\) has velocity units, and \(r_h\) has length units. That unit mismatch is not cosmetic. It is why we normalize outputs for training and why we must be careful when building a physical likelihood later.

Let \(\Phi_T\) denote the numerical N-body evolution from \(t = 0\) to final time \(T\), and let \(s\) denote the summary-statistic extractor applied to the final particle state. The physical forward model is the composition

\[ \mathbf{y} = s\!\left(\Phi_T(\mathbf{x})\right). \]

Let’s unpack the pieces:

  • \(\mathbf{x} = (Q_0, a)\) is the physical input vector.
  • \(\Phi_T\) is the numerical flow map produced by your JAX-native N-body integrator.
  • \(s\) extracts summary statistics such as bound fraction, velocity dispersion, and half-mass radius.
  • \(\mathbf{y}\) is the summary vector used for training and inference.

The emulator approximates this composition:

\[ f_{\boldsymbol{\theta}}(\mathbf{x}) \approx s\!\left(\Phi_T(\mathbf{x})\right). \]

What this equation is really saying: the network is not learning Newton’s laws directly. It is learning a compact approximation to the simulator’s input-output map over the parameter domain you sampled. Whether that approximation captures scientifically meaningful structure must be demonstrated by validation, residual analysis, and failure-mode checks.

flowchart TD
    x["Physical inputs\n(Q0, a)"] --> phi["N-body evolution Phi_T"]
    phi --> zT["Final particle state"]
    zT --> s["Summary extractor s"]
    s --> y["Summary vector\n(f_bound, sigma_v, r_h)"]
    y --> D["Training data"]
    D --> emu["MLP emulator f_theta(Q0, a)"]
    emu --> res["Residuals\ny - f_theta(x)"]
    res --> like["Likelihood\np(y_obs | Q0, a)"]
    like --> post["Posterior\np(Q0, a | y_obs)"]

Note: Connection to Module 1: Models as Compression

Remember Module 1’s insight that models compress information? Summary statistics compress high-dimensional simulation outputs (particle positions and velocities) into meaningful numbers (\(f_{\rm bound}\), \(\sigma_v\), \(r_h\)). An emulator goes further — it compresses the simulator’s summary-statistic map into a set of neural-network weights.

This is learned compression: the network discovers a useful representation of the mapping over the sampled domain. Your approximately 4500 network parameters encode a fitted approximation to how initial virial ratio and scale radius map to the chosen cluster summaries. They do not encode the fundamental equations of motion in the way your simulator does.

Emulators in Modern Astrophysics

This isn’t a toy problem. Emulation is transforming computational science:

Field | Expensive forward model | Emulator predicts | Inference target
Cosmology | N-body or hydrodynamic simulations | matter power spectrum or summary statistics | cosmological parameters
Stellar evolution | MESA stellar tracks | luminosity, radius, lifetime, or \(T_{\rm eff}\) | age, mass, metallicity
Galaxy formation | large hydrodynamic simulations | galaxy properties from model parameters | feedback or cosmological parameters
Final project | cluster N-body simulation | \((f_{\rm bound}, \sigma_v, r_h)\) | \((Q_0, a)\)

The common pattern is not “replace physics with ML.” The common pattern is: trust a validated simulator, train a fast approximation to its summary outputs, and use that approximation only where validation supports it.

Important: Minimum Emulator Workflow

If you feel overwhelmed by the vocabulary, keep this compact recipe in mind:

  1. Build and validate your JAX-native simulator first.
  2. Generate a modest training set over a clearly defined parameter range.
  3. Compute summary statistics that will be your emulator targets.
  4. Normalize inputs and outputs using training-set statistics only.
  5. Train one small MLP as a baseline emulator.
  6. Compare against mean, linear, and simple nearest-neighbor or interpolation-style baselines where feasible.
  7. Add an ensemble for one component of epistemic uncertainty and check calibration.
  8. Build a physical-unit likelihood and plug the emulator into NumPyro.

That is the core final-project lane. Everything else in this module is here to help you understand and defend those steps.
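Step 4 of the recipe is the one most often done wrong, so here is a minimal sketch of it. It assumes standardization (subtract the mean, divide by the standard deviation) as the normalization scheme; the toy `x_all` array and the helper names `normalize` and `denormalize` are hypothetical, and the numbers carry no physical meaning.

```python
import jax.numpy as jnp

# Hypothetical toy design: each row is one (Q0, a) sample. The values are
# invented for illustration only.
x_all = jnp.array([
    [0.5, 50.0],
    [1.0, 80.0],
    [1.5, 110.0],
    [0.8, 140.0],
    [1.2, 65.0],
])

# Split first, then compute statistics from the training rows only.
x_train, x_val = x_all[:4], x_all[4:]
x_mean = x_train.mean(axis=0)
x_std = x_train.std(axis=0)


def normalize(x):
    """Standardize with training-set statistics."""
    return (x - x_mean) / x_std


def denormalize(x_norm):
    """Invert normalize, returning to the original units."""
    return x_norm * x_std + x_mean


x_train_n = normalize(x_train)
x_val_n = normalize(x_val)  # reuses the training statistics, never its own
```

The same pattern applies to the outputs: compute the target mean and standard deviation on the training split, train in normalized units, and call the inverse transform before comparing against physical-unit data.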


1.2: From Linear Regression to Neural Networks

Priority: Important

Before diving into neural network details, let’s see how they emerge naturally from ideas you already know.

The Progression of Function Approximation

Feature engineering: Manually constructing input transformations (like \(x^2\), \(\log x\), \(x_1 x_2\)) to capture nonlinear relationships. Traditional approach before neural networks automated this process.

Level 1: Linear Regression

The simplest model: \(\hat{y} = \mathbf{w}^T \mathbf{x} + b\)

You find weights \(\mathbf{w}\) and bias \(b\) that minimize squared error. This works beautifully when the true relationship is linear — but nature rarely cooperates.

Level 2: Polynomial/Basis Regression

To capture nonlinearity, you engineer features: \(\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 + \ldots\)

This is still linear in the parameters (the \(w_i\)), but nonlinear in the input. The problem: you must choose which features to include. For your emulator, should you use \(Q_0^2\)? \(\log a\)? \(Q_0 \times a\)? You’d have to guess.

Level 3: Neural Networks

Here’s the key insight: what if we learned the features?

A neural network with one hidden layer computes: \[\hat{y} = \mathbf{w}_{\rm out}^T \underbrace{\sigma(\mathbf{W}_{\rm hidden} \mathbf{x} + \mathbf{b}_{\rm hidden})}_{\text{learned features}} + b_{\rm out}\]

The hidden layer creates learned nonlinear features of the input. The output layer then does linear regression on these features. Training optimizes both the feature extraction and the final regression simultaneously.
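To make the "learned features plus linear readout" structure concrete, here is a minimal sketch of the one-hidden-layer formula. It assumes ReLU as the activation \(\sigma\), a hidden width of 8, and random untrained weights; all of these are illustrative choices, not the final-project settings.

```python
import jax
import jax.numpy as jnp

key_w, key_b, key_out = jax.random.split(jax.random.PRNGKey(0), 3)
n_hidden = 8  # illustrative width

# Untrained parameters: training would adjust all of them together.
W_hidden = jax.random.normal(key_w, (n_hidden, 2))
b_hidden = jax.random.normal(key_b, (n_hidden,))
w_out = jax.random.normal(key_out, (n_hidden,))
b_out = 0.0


def learned_features(x):
    """Hidden layer: nonlinear features of the raw input."""
    return jax.nn.relu(W_hidden @ x + b_hidden)


def predict(x):
    """Output layer: ordinary linear regression on the hidden features."""
    return w_out @ learned_features(x) + b_out


x = jnp.array([1.0, 0.5])
phi = learned_features(x)  # shape (8,)
y_hat = predict(x)         # scalar
```

If you froze `W_hidden` and `b_hidden`, fitting `w_out` and `b_out` would be exactly the basis regression of Level 2. The difference in a neural network is that the features are trained too.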

Note: Three Levels of Function Approximation

Approach | What the human chooses | What the algorithm learns
Linear regression | The raw inputs | Weights on those inputs
Polynomial regression | Features such as \(x^2\) and \(x^3\) | Weights on chosen features
Neural network | Architecture and training data | Nonlinear features and weights together

The neural network does not remove human judgment. You still choose the data, architecture, loss, and validation evidence.

Important: The Profound Shift

  • Traditional statistics: Human chooses features \(\to\) algorithm finds weights
  • Neural networks: Algorithm finds features and weights

This is why neural networks work across such diverse problems — they discover task-appropriate representations automatically. The same architecture that predicts cluster dynamics can (with different trained weights) recognize galaxies in images or denoise spectra.

Note: Visual Companion: What Is a Neural Network?

3Blue1Brown’s “But what is a neural network?” introduces the same basic architecture through handwritten-digit recognition: input activations, hidden layers, weights, biases, and output activations.

Lecture framing: 3Blue1Brown uses handwritten digits. Our problem is not classification; it is scientific emulation.

As you watch, translate the example into the final-project emulator:

  • Digit pixels become physical inputs \((Q_0, a)\).
  • Digit probabilities become summary-statistic predictions \((f_{\rm bound}, \sigma_v, r_h)\).
  • Image features become learned nonlinear features of the simulation design space.

The example is classification; your emulator is regression. The computational structure is the same kind of layered function.

Credit: 3Blue1Brown / Grant Sanderson

Where Does Your Emulator Fit?

Your task: learn the mapping \((Q_0, a) \to (f_{\rm bound}, \sigma_v, r_h)\).

Is this relationship linear? Almost certainly not. The physics of gravitational collapse, violent relaxation, and stellar escape involves complex nonlinear dynamics. You could try to engineer features based on virial theorem scaling relations — but why guess when you can learn?

A neural network will discover whatever nonlinear transformations of \((Q_0, a)\) are useful for predicting cluster outcomes. If it turns out linear features suffice, the network can learn that too (it includes linear models as a special case).


1.3: The Computational Neuron

Priority: Essential

Anatomy of a Neuron

Artificial neuron: A computational unit that computes a weighted sum of inputs, adds a bias, and applies a nonlinear activation function. Inspired by (but vastly simpler than) biological neurons.

The artificial neuron is the atomic unit of neural networks. It takes \(d\) inputs \(x_1, x_2, \ldots, x_d\) and produces one output:

\[a = \sigma\left(\sum_{j=1}^{d} w_j x_j + b\right) = \sigma(\mathbf{w}^T \mathbf{x} + b)\]

where:

  • \(\mathbf{w} = (w_1, \ldots, w_d)\) are weights controlling how much each input matters
  • \(b\) is a bias allowing the neuron to shift its activation threshold
  • \(\sigma\) is a nonlinear activation function

Think of it as “how strongly should this neuron fire, given these inputs?” The weights determine sensitivity to each input; the bias sets the baseline; the activation function introduces nonlinearity.

Pre-activation: The quantity \(z = \mathbf{w}^T \mathbf{x} + b\) before the nonlinearity acts. Represents “how much this neuron wants to activate” before thresholding.

The quantity \(z = \mathbf{w}^T \mathbf{x} + b\) is called the pre-activation — the neuron’s response before the nonlinear squashing. The activation function then decides the actual output: \(a = \sigma(z)\).

Note: Neuron Computation Map

Input features \(x_1, x_2, \ldots, x_d\) are multiplied by weights \(w_1, w_2, \ldots, w_d\), summed with a bias \(b\), and passed through an activation: \[ z = \sum_{j=1}^{d} w_j x_j + b, \qquad a = \sigma(z). \]

The important split is linear pre-activation followed by nonlinear activation.
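A single neuron is only a few lines of code. The sketch below assumes ReLU as the activation; the function name `neuron` and the numbers are made up for illustration.

```python
import jax
import jax.numpy as jnp


def neuron(x, w, b):
    """Return the pre-activation z and the activation a of one neuron."""
    z = jnp.dot(w, x) + b  # linear pre-activation
    a = jax.nn.relu(z)     # nonlinear activation
    return z, a


x = jnp.array([2.0, 1.0])
w = jnp.array([0.5, -1.0])

# 0.5 * 2.0 - 1.0 * 1.0 + 0.25 = 0.25, which ReLU passes through unchanged.
z_on, a_on = neuron(x, w, 0.25)

# Lowering the bias pushes z below zero, so ReLU silences the neuron.
z_off, a_off = neuron(x, w, -0.5)
```

Changing only the bias moves the neuron across its threshold, which is the "baseline" role described above.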

Why Nonlinearity Is Essential

Important: The Collapse of Linear Networks

Without activation functions, a neuron computes a linear transformation: \(a = \mathbf{w}^T \mathbf{x} + b\).

Stack two linear layers: \[\text{Layer 1: } \mathbf{h} = \mathbf{W}_1 \mathbf{x} + \mathbf{b}_1\] \[\text{Layer 2: } \mathbf{y} = \mathbf{W}_2 \mathbf{h} + \mathbf{b}_2\]

Substitute: \[\mathbf{y} = \mathbf{W}_2 (\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2 = \underbrace{(\mathbf{W}_2 \mathbf{W}_1)}_{\text{one matrix}} \mathbf{x} + \underbrace{(\mathbf{W}_2 \mathbf{b}_1 + \mathbf{b}_2)}_{\text{one bias}}\]

No matter how many layers, you’d collapse to a single linear transformation. Deep networks would be no more powerful than shallow ones!

The nonlinearity \(\sigma\) prevents this collapse. It’s what allows networks to represent complex, nonlinear relationships — like how a cluster’s bound fraction depends on initial virial ratio through the physics of violent relaxation and escape.
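You can confirm the collapse numerically. This is a minimal sketch with random weights and illustrative layer sizes (2 inputs, 4 hidden units, 3 outputs), not the emulator architecture.

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
W1 = jax.random.normal(k1, (4, 2))
b1 = jax.random.normal(k2, (4,))
W2 = jax.random.normal(k3, (3, 4))
b2 = jax.random.normal(k4, (3,))
x = jax.random.normal(k5, (2,))

# Two stacked layers with no activation in between.
two_layers = W2 @ (W1 @ x + b1) + b2

# The equivalent single layer from the substitution above.
W_eff = W2 @ W1
b_eff = W2 @ b1 + b2
one_layer = W_eff @ x + b_eff

collapsed = bool(jnp.allclose(two_layers, one_layer, atol=1e-4))

# Inserting ReLU breaks the equivalence in general: no single (W, b) pair
# reproduces this map for every input x.
with_relu = W2 @ jax.nn.relu(W1 @ x + b1) + b2
```

The agreement holds up to floating-point rounding for any choice of weights and input, because it is an algebraic identity rather than a property of these particular numbers.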

Activation Functions

Several activation functions are commonly used:

ReLU (Rectified Linear Unit) — the modern default: \[\text{ReLU}(z) = \max(0, z) = \begin{cases} z & \text{if } z > 0 \\ 0 & \text{if } z \leq 0 \end{cases}\]

Beautifully simple: pass positive values unchanged, zero out negative values. Computationally cheap and works remarkably well. The gradient is 1 for \(z > 0\) and 0 for \(z < 0\), which helps training (no “vanishing gradients” for positive inputs).

Sigmoid (the “Fermi function” you know from statistical mechanics): \[\sigma(z) = \frac{1}{1 + e^{-z}}\]

Squashes outputs to \((0, 1)\). Useful for probabilities, but can cause vanishing gradients in deep networks (derivative is small when \(|z|\) is large).

Tanh: \[\tanh(z) = \frac{e^z - e^{-z}}{e^z + e^{-z}}\]

Squashes to \((-1, 1)\). Zero-centered, which can help optimization. Same vanishing gradient concern as sigmoid.

Note: Activation Function Comparison

Activation | Range | Useful intuition | Main caution
ReLU | \([0, \infty)\) | Cheap, sparse, strong default for hidden layers | Zero gradient for inactive units
Sigmoid | \((0, 1)\) | Probability-like output | Saturates for large \(\lvert z \rvert\)
Tanh | \((-1, 1)\) | Zero-centered sigmoid-like shape | Also saturates for large \(\lvert z \rvert\)
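The cautions in the comparison show up directly in the derivatives. Here is a short sketch that evaluates the three activations and their gradients at a few sample points; the sample values in `z` are arbitrary.

```python
import jax
import jax.numpy as jnp

z = jnp.array([-5.0, -1.0, 0.0, 1.0, 5.0])

relu = jax.nn.relu(z)
sigmoid = jax.nn.sigmoid(z)
tanh = jnp.tanh(z)

# jax.grad differentiates a scalar function; vmap applies it to each entry.
d_relu = jax.vmap(jax.grad(jax.nn.relu))(z)
d_sigmoid = jax.vmap(jax.grad(jax.nn.sigmoid))(z)
d_tanh = jax.vmap(jax.grad(jnp.tanh))(z)

# Saturation: sigmoid and tanh gradients shrink toward zero at |z| = 5,
# while the ReLU gradient stays at 1 for every positive input.
print(d_sigmoid)
print(d_tanh)
print(d_relu)
```

The sigmoid gradient peaks at 0.25 (at \(z = 0\)) and the tanh gradient peaks at 1, so repeated multiplication through many saturating layers shrinks the signal quickly.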

Historical note: Early neural networks (1980s–90s) used sigmoid and tanh. ReLU was popularized around 2010 and revolutionized deep learning by alleviating the vanishing gradient problem. Sometimes the simplest solutions win.

For your emulator: Use ReLU for hidden layers. It’s the standard choice for regression and works well in practice.


1.4: Layering Neurons into Networks

Priority: Essential

The Multi-Layer Perceptron

Multi-Layer Perceptron (MLP): A neural network architecture consisting of an input layer, one or more hidden layers, and an output layer. Also called a “feed-forward” network because information flows in one direction (no loops).

A single neuron has limited expressiveness. The power comes from layering neurons into a network:

Input layer: Your raw features \(\mathbf{x} \in \mathbb{R}^{d_{\text{in}}}\). Not really “neurons” — just the data entering the network.

Hidden layer(s): Where the magic happens. Each layer transforms its input through weights, biases, and activations. A hidden layer with \(n\) neurons computes:

\[\mathbf{h} = \sigma(\mathbf{W}\mathbf{x} + \mathbf{b})\]

where \(\mathbf{W} \in \mathbb{R}^{n \times d_{\text{in}}}\) is the weight matrix, \(\mathbf{b} \in \mathbb{R}^n\) is the bias vector, and \(\sigma\) applies element-wise.

Output layer: Produces final predictions \(\hat{\mathbf{y}} \in \mathbb{R}^{d_{\text{out}}}\). For regression, typically no activation (linear output).

Your Emulator Architecture

For the final project, you’ll build:

  • Input: 2 features \((Q_0, a)\)
  • Hidden 1: 64 neurons, ReLU activation
  • Hidden 2: 64 neurons, ReLU activation
  • Output: 3 predictions \((f_{\text{bound}}, \sigma_v, r_h)\), linear (no activation)

Units reminder: Ensure \(\sigma_v\) uses units consistent with your JAX simulator (typically km/s or your chosen internal N-body units). The scale radius \(a\) and half-mass radius \(r_h\) should use the same length convention throughout the pipeline.

flowchart LR
    subgraph Input["Input Layer (2)"]
        Q0["Q0"]
        a["a"]
    end

    subgraph H1["Hidden Layer 1\n64 neurons, ReLU"]
        h1["h1 layer 1"]
        h2["..."]
        h3["h64 layer 1"]
    end

    subgraph H2["Hidden Layer 2\n64 neurons, ReLU"]
        h4["h1 layer 2"]
        h5["..."]
        h6["h64 layer 2"]
    end

    subgraph Output["Output Layer (3)\nLinear"]
        f["f_bound"]
        s["sigma_v"]
        r["r_h"]
    end

    Q0 --> H1
    a --> H1
    H1 --> H2
    H2 --> Output

Counting Parameters

Learnable parameters: The weights and biases that training optimizes. The “knobs” the network adjusts to fit data.

How many learnable parameters does your network have? This matters for understanding model complexity.

Layer 1 (input \(\to\) hidden 1):

  • Weight matrix \(\mathbf{W}_1 \in \mathbb{R}^{64 \times 2}\): \(64 \times 2 = 128\) weights
  • Bias vector \(\mathbf{b}_1 \in \mathbb{R}^{64}\): 64 biases
  • Subtotal: 192 parameters

Layer 2 (hidden 1 \(\to\) hidden 2):

  • Weight matrix \(\mathbf{W}_2 \in \mathbb{R}^{64 \times 64}\): \(64 \times 64 = 4096\) weights
  • Bias vector \(\mathbf{b}_2 \in \mathbb{R}^{64}\): 64 biases
  • Subtotal: 4,160 parameters

Layer 3 (hidden 2 \(\to\) output):

  • Weight matrix \(\mathbf{W}_3 \in \mathbb{R}^{3 \times 64}\): \(3 \times 64 = 192\) weights
  • Bias vector \(\mathbf{b}_3 \in \mathbb{R}^{3}\): 3 biases
  • Subtotal: 195 parameters

Total: 4,547 parameters
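The same bookkeeping works for any fully connected architecture. Here is a small sketch; the helper name `count_mlp_parameters` is illustrative.

```python
def count_mlp_parameters(layer_sizes):
    """Count weights and biases in a fully connected network.

    layer_sizes lists the width of every layer, input first and output last.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_out * n_in  # weight matrix of shape (n_out, n_in)
        total += n_out         # one bias per output neuron
    return total


n_params = count_mlp_parameters([2, 64, 64, 3])
print(n_params)  # 4547
```

Almost all of the parameters sit in the 64-by-64 matrix between the hidden layers, which is why widening the hidden layers grows the count roughly quadratically.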

Warning: Connection to Module 1: Degrees of Freedom and Overfitting

With 4,547 parameters and only ~100 training examples, alarm bells should ring. You have ~45\(\times\) more parameters than data points!

In Module 1, you learned that model complexity must balance against data availability. Too many degrees of freedom and you fit noise, not signal — the model memorizes training data but fails on new inputs.

Why doesn’t this automatically doom neural networks? Several factors help, but none of them remove the need for validation:

  1. Implicit regularization: Gradient descent with early stopping tends to find “simple” solutions
  2. Low-dimensional input domain: The map from \((Q_0, a)\) to three summaries is much smaller than an image-recognition problem
  3. Expected smoothness: Nearby initial conditions often produce nearby summary statistics, except near real dynamical transitions
  4. Architecture restraint: This MLP is small relative to modern deep networks
  5. Validation: We monitor validation performance to detect overfitting while keeping the test set untouched
  6. Normalization and early stopping: These keep optimization better behaved

Still, with limited data, you should expect some overfitting risk. The emulator is scientifically valid only inside the parameter domain covered by the training design, and only where held-out residuals and calibration checks support it.


1.5: The Universal Approximation Theorem

Priority: Important

The Theoretical Foundation

Universal Approximation Theorem: First proved by Cybenko (1989) for sigmoid activations; later extended to ReLU and other activations by Hornik, Leshno, and others.

Here’s a remarkable mathematical result that justifies using neural networks as flexible function approximators:

Theorem (Universal Approximation): Let \(f: K \to \mathbb{R}\) be any continuous function defined on a compact set \(K \subset \mathbb{R}^d\). For any \(\epsilon > 0\), there exists a neural network \(\hat{f}\) with a single hidden layer and sufficiently many neurons such that:

\[\sup_{\mathbf{x} \in K} |f(\mathbf{x}) - \hat{f}(\mathbf{x})| < \epsilon\]

In words: any continuous function can be approximated arbitrarily well by a neural network with one hidden layer, if you use enough neurons.

What this means for emulation: If the map from initial conditions to summary statistics is continuous on the sampled \((Q_0, a)\) domain, the theorem guarantees that some neural network can approximate it to any desired accuracy on that domain. Continuity is an assumption to check rather than a given: stochastic realizations and real dynamical transitions can make the summaries change sharply with the inputs.

The Intuition: Building Functions from Simple Pieces

How can a network approximate arbitrary functions?

For sigmoid activations: Each neuron creates a “soft step” — a smooth transition from low to high output. By positioning many soft steps at different locations with different heights, you can approximate any shape. It’s like building a sculpture from many small clay pieces, each contributing a local bump or dip.

For ReLU activations: Each neuron creates a “hinge”, a bend in the function at some location. For a one-dimensional input, a single hidden layer with \(n\) ReLU neurons produces a piecewise linear function with up to \(n\) hinges. With enough hinges, piecewise linear functions can approximate any continuous curve on a bounded interval.

Note: ReLU Approximation Intuition

Picture a smooth target curve, then approximate it with more and more straight-line segments. A few ReLU neurons make a crude piecewise-linear sketch; more neurons add more hinges; enough well-placed hinges can track the curve closely over the training domain.

The phrase “over the training domain” matters. This theorem does not promise sensible extrapolation outside the sampled range.
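The hinge picture can be built by hand, without any training. The sketch below constructs a one-hidden-layer ReLU network whose weights are chosen so that it linearly interpolates a target between evenly spaced knots. The target \(\sin x\), the interval \([0, 2\pi]\), and the helper names `relu_interpolant` and `max_error` are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def relu_interpolant(f, lo, hi, n_hinges):
    """Hand-built ReLU network that interpolates f at n_hinges + 1 knots."""
    knots = jnp.linspace(lo, hi, n_hinges + 1)
    values = f(knots)
    slopes = jnp.diff(values) / jnp.diff(knots)  # slope of each segment
    # Output weights: the first unit sets the initial slope, and every later
    # unit adds the change in slope at its hinge.
    coeffs = jnp.concatenate([slopes[:1], jnp.diff(slopes)])
    locations = knots[:-1]  # hinge positions (hidden biases are -locations)

    def g(x):
        hidden = jax.nn.relu(x[..., None] - locations)  # hidden weights are 1
        return values[0] + hidden @ coeffs

    return g


def max_error(n_hinges):
    """Largest absolute error against sin on a dense grid inside the domain."""
    g = relu_interpolant(jnp.sin, 0.0, 2.0 * jnp.pi, n_hinges)
    xs = jnp.linspace(0.0, 2.0 * jnp.pi, 1001)
    return float(jnp.max(jnp.abs(g(xs) - jnp.sin(xs))))


err_coarse = max_error(4)   # crude sketch of the curve
err_fine = max_error(32)    # more hinges, much closer fit
```

Evaluating `g` outside \([0, 2\pi]\) continues the last straight segment forever, which is a concrete picture of why the network gives no guarantees beyond the sampled range.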

The Crucial Caveats

Warning: What the Theorem Does NOT Guarantee

1. How many neurons you need: The theorem says “sufficiently many” but the number could be astronomically large for complex functions. It’s an existence theorem, not a construction.

2. How to find the right weights: Knowing a good network exists doesn’t tell you how to train it. Optimization might get stuck in poor local minima.

3. That training data suffices: The theorem assumes you can evaluate the true function anywhere. In practice, you only have finite samples. Generalization from samples to the full function is a separate challenge.

4. Behavior outside the training region: The theorem applies on a compact set \(K\). Outside this region, the network extrapolates with no guarantees.

Universal approximation \(\neq\) universal learning. The theorem speaks to expressiveness (what networks can represent), not learnability (what we can find from data).

For the final project, the theorem only tells us that a neural network could approximate the simulator summary map on the sampled \((Q_0, a)\) domain. It does not tell us that your training set is dense enough, that Adam found the right approximation, that the emulator respects cluster dynamics outside the training range, or that the resulting posterior will be calibrated.


1.6: Forward Propagation

Priority: Essential

Computing Predictions

Forward propagation (or forward pass): The process of computing a network’s output from its input by applying each layer’s transformation sequentially.

Forward propagation is how we compute predictions: apply each layer’s transformation in sequence, feeding outputs forward.

For your 2-hidden-layer network, the forward pass computes:

\[\mathbf{z}_1 = \mathbf{W}_1 \mathbf{x} + \mathbf{b}_1 \quad \text{(pre-activation, layer 1)}\] \[\mathbf{h}_1 = \text{ReLU}(\mathbf{z}_1) \quad \text{(activation, layer 1)}\] \[\mathbf{z}_2 = \mathbf{W}_2 \mathbf{h}_1 + \mathbf{b}_2 \quad \text{(pre-activation, layer 2)}\] \[\mathbf{h}_2 = \text{ReLU}(\mathbf{z}_2) \quad \text{(activation, layer 2)}\] \[\hat{\mathbf{y}} = \mathbf{W}_3 \mathbf{h}_2 + \mathbf{b}_3 \quad \text{(output, linear)}\]

Each layer takes the previous layer’s output, applies a linear transformation (matrix multiply + bias), and passes through an activation (except the final layer).
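The five equations translate almost line for line into JAX. This is a minimal sketch with untrained weights: the helper names `init_params` and `forward` are illustrative, and the \(\sqrt{2 / n_{\rm in}}\) weight scale is one common choice assumed here rather than a prescribed setting.

```python
import jax
import jax.numpy as jnp


def init_params(key, sizes=(2, 64, 64, 3)):
    """Random (W, b) pair for each layer; the scale is an assumed choice."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        W = jax.random.normal(k, (n_out, n_in)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params


def forward(params, x):
    """Forward pass: ReLU on the hidden layers, linear output layer."""
    h = x
    for W, b in params[:-1]:
        h = jax.nn.relu(W @ h + b)  # z = W h + b, then h = ReLU(z)
    W_out, b_out = params[-1]
    return W_out @ h + b_out        # no activation on the output


params = init_params(jax.random.PRNGKey(0))
x = jnp.array([1.0, 0.5])    # one (Q0, a) pair, assumed already normalized
y_hat = forward(params, x)   # shape (3,): normalized summary predictions
```

To evaluate a whole batch of inputs at once, `jax.vmap(forward, in_axes=(None, 0))(params, x_batch)` maps the same function over the leading axis of `x_batch`.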

Note: Why No Activation on the Output Layer?

For regression tasks (predicting continuous values like \(f_{\text{bound}}\), \(\sigma_v\), \(r_h\)), the output layer should be linear — no activation function.

Activations like sigmoid squash outputs into the bounded range \((0, 1)\). If your targets span ranges like \(r_h \in [20, 150]\) AU, a sigmoid would artificially constrain predictions. ReLU would forbid negative predictions, which might be acceptable for \(f_{\rm bound} \geq 0\) in physical units but fails if targets are standardized to zero mean, because a normalized target is negative whenever the raw value falls below the training mean.

For classification tasks, you’d use sigmoid (binary) or softmax (multiclass) to produce probabilities. But your emulator does regression, so keep the output linear.

The Computational Graph Perspective

Computational graph: A directed acyclic graph representing how computations compose. Nodes are operations or values; edges show data flow. JAX builds these automatically when tracing functions.

Forward propagation defines a computational graph — exactly the structure JAX uses for automatic differentiation.

flowchart TB
    x["x input"] --> z1["z1 = W1 x + b1"]
    W1["W1"] --> z1
    b1["b1"] --> z1
    z1 --> h1["h1 = ReLU(z1)"]
    h1 --> z2["z2 = W2 h1 + b2"]
    W2["W2"] --> z2
    b2["b2"] --> z2
    z2 --> h2["h2 = ReLU(z2)"]
    h2 --> y["y_hat = W3 h2 + b3"]
    W3["W3"] --> y
    b3["b3"] --> y
    y --> L["Loss L"]
    ytrue["y_true"] --> L

This graph shows exactly how the output \(\hat{\mathbf{y}}\) depends on the input \(\mathbf{x}\) and all parameters \(\boldsymbol{\theta} = (\mathbf{W}_1, \mathbf{b}_1, \mathbf{W}_2, \mathbf{b}_2, \mathbf{W}_3, \mathbf{b}_3)\).

When we train, we need gradients \(\nabla_{\boldsymbol{\theta}} \mathcal{L}\). The computational graph tells JAX exactly how to compute them via the chain rule — that’s what jax.grad does automatically.
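Here is a small example you can verify by hand. It is a sketch on a tiny network (2 inputs, 2 hidden units, 1 output) with made-up weights chosen so the arithmetic is easy; the function names `predict` and `loss` are illustrative.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    h = jax.nn.relu(params["W1"] @ x + params["b1"])
    return params["W2"] @ h + params["b2"]


def loss(params, x, y_true):
    """Squared error for a single example."""
    return jnp.sum((predict(params, x) - y_true) ** 2)


params = {
    "W1": jnp.array([[1.0, -1.0], [0.5, 2.0]]),
    "b1": jnp.array([0.0, -1.0]),
    "W2": jnp.array([[1.0, 1.0]]),
    "b2": jnp.array([0.0]),
}
x = jnp.array([2.0, 1.0])
y_true = jnp.array([0.0])

# Forward by hand: z1 = [1, 2], h = [1, 2], y_hat = 3, loss = 9.
value = loss(params, x, y_true)

# jax.grad differentiates with respect to the first argument and returns a
# dictionary with the same keys and shapes as params.
grads = jax.grad(loss)(params, x, y_true)

# Chain rule by hand: dL/dy_hat = 2 * 3 = 6, so dL/db2 = [6],
# dL/dW2 = 6 * h = [[6, 12]], and dL/dW1 = outer([6, 6], x).
```

Because the gradients come back in the same structure as the parameters, an update step can be applied leaf by leaf without any manual indexing.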


Conceptual Checkpoint: Architecture Foundations

Before moving to training, verify your understanding:

Important: Self-Assessment
  1. Nonlinearity: If you removed all ReLU activations from your network, what class of functions could it represent? Why is this limiting?

  2. Parameter count: A network has layers of sizes 10 \(\to\) 50 \(\to\) 50 \(\to\) 5. How many total parameters (weights + biases)?

  3. Forward pass: Given input \(\mathbf{x} = [0.5, 1.0]^T\) and first layer weights \(\mathbf{W}_1 = \begin{pmatrix} 1 & -1 \\ 0 & 2 \end{pmatrix}\) with bias \(\mathbf{b}_1 = [0, -1]^T\), compute \(\mathbf{h}_1\) (the first hidden layer output, with ReLU).

  4. Architecture choice: Why does your emulator use ReLU for hidden layers but no activation for the output layer?

  1. Only linear functions \(\hat{\mathbf{y}} = \mathbf{A}\mathbf{x} + \mathbf{c}\). Limiting because most physical relationships are nonlinear.

  2. Layer 1: \(50 \times 10 + 50 = 550\). Layer 2: \(50 \times 50 + 50 = 2550\). Layer 3: \(5 \times 50 + 5 = 255\). Total: 3,355 parameters.

  3. \(\mathbf{z}_1 = \begin{pmatrix} 1 & -1 \\ 0 & 2 \end{pmatrix} \begin{pmatrix} 0.5 \\ 1.0 \end{pmatrix} + \begin{pmatrix} 0 \\ -1 \end{pmatrix} = \begin{pmatrix} -0.5 \\ 1.0 \end{pmatrix}\). Then \(\mathbf{h}_1 = \text{ReLU}(\mathbf{z}_1) = \begin{pmatrix} 0 \\ 1.0 \end{pmatrix}\).

  4. ReLU provides nonlinearity so hidden layers can learn complex features. Output layer is linear because we’re doing regression — predictions should be unconstrained real numbers.
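If you want to check answers 2 and 3 numerically, a minimal sketch follows. The helper name `count_parameters` is hypothetical and introduced only for this check.

```python
import jax.numpy as jnp

def count_parameters(layer_sizes):
    """Weights plus biases for a fully connected network."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))

n_params = count_parameters([10, 50, 50, 5])
print(n_params)                           # 3355
print(count_parameters([2, 64, 64, 3]))   # 4547 for the 2-64-64-3 emulator

# Question 3: first hidden layer with ReLU
W1 = jnp.array([[1.0, -1.0], [0.0, 2.0]])
b1 = jnp.array([0.0, -1.0])
x = jnp.array([0.5, 1.0])
z1 = W1 @ x + b1
h1 = jnp.maximum(0.0, z1)
print(z1, h1)
```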


1.7: Training as Optimization

Priority: Essential

The Loss Function

Loss function (or cost function, objective function): A scalar function measuring how wrong the network’s predictions are. Training minimizes this function.

Training requires a measure of “how wrong” the network is. For regression, the standard choice is mean squared error (MSE):

\[ \mathcal{L}(\boldsymbol{\theta}) = \frac{1}{N} \sum_{i=1}^{N} \left\| \hat{\mathbf{y}}_i - \mathbf{y}_i^{\rm true} \right\|^2, \]

where \(\boldsymbol{\theta}\) denotes all network parameters, \(N\) is the number of training examples, and \(\hat{\mathbf{y}}_i = f_{\text{network}}(\mathbf{x}_i; \boldsymbol{\theta})\) is the prediction for input \(\mathbf{x}_i\).

The loss is a function of the parameters: different weights produce different predictions, which produce different errors. Training finds parameters that minimize the loss.

NoteTraining Loss, Validation Metric, Scientific Likelihood

These three quantities are related, but they are not interchangeable:

Quantity | Used for | Example
Training loss | fitting network weights | MSE on normalized outputs
Validation metric | detecting generalization failure | RMSE on held-out simulations
Scientific likelihood | inference over physical parameters | Gaussian likelihood in physical output units

In Bayesian inference, you maximize or sample from a posterior. In neural-network training, you minimize a loss. These are connected when the loss can be derived from a likelihood, but the interpretation depends on the assumptions.

For a multi-output emulator, a simple diagonal Gaussian residual model says

\[ y_{ik} = f_{\boldsymbol{\theta},k}(\mathbf{x}_i) + \epsilon_{ik}, \qquad \epsilon_{ik} \sim \mathcal{N}(0, \sigma_k^2), \]

where \(k\) indexes the output component: \(f_{\rm bound}\), \(\sigma_v\), or \(r_h\).

Under independent residuals across examples and output components,

\[ p(\mathbf{Y} \,|\, \mathbf{X}, \boldsymbol{\theta}, \boldsymbol{\sigma}) = \prod_{i=1}^{N} \prod_{k=1}^{3} \mathcal{N} \left( y_{ik} \,\middle|\, f_{\boldsymbol{\theta},k}(\mathbf{x}_i), \sigma_k^2 \right). \]

If all outputs are normalized and given equal unit variance for training, minimizing MSE is equivalent to maximizing this Gaussian likelihood under a fixed equal-variance residual model in normalized space. That is useful for optimization. It is not automatically the same as the physical likelihood you will use in NumPyro.

The physical likelihood must compare physical-unit predictions to physical-unit observations, with uncertainties in the same units as the outputs.
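The MSE/likelihood equivalence can be checked numerically. The sketch below uses random stand-in arrays for targets and predictions (not emulator data) and assumes \(\sigma_k = 1\) for every output, as in the normalized-space argument above. With the MSE defined as the squared norm per example averaged over \(N\) examples, the negative log-likelihood equals \(\tfrac{N}{2}\,\text{MSE}\) plus a constant that does not depend on \(\boldsymbol{\theta}\).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key_y, key_f = jax.random.split(jax.random.PRNGKey(0))
N, K = 20, 3                                      # examples, output components
y = jax.random.normal(key_y, (N, K))              # stand-in normalized targets
f = y + 0.1 * jax.random.normal(key_f, (N, K))    # stand-in predictions

# MSE as defined above: squared norm per example, averaged over examples
mse = jnp.mean(jnp.sum((f - y) ** 2, axis=1))

# Negative log-likelihood of the diagonal Gaussian model with sigma_k = 1
nll = -jnp.sum(norm.logpdf(y, loc=f, scale=1.0))

# The two differ by the factor N/2 and a theta-independent constant
const = 0.5 * N * K * jnp.log(2.0 * jnp.pi)
print(nll, 0.5 * N * mse + const)
```

Because the constant and the factor \(N/2\) do not change where the minimum sits, the parameters that minimize the MSE are the ones that maximize this particular likelihood.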

Gradient Descent

Gradient descent: An iterative optimization algorithm that updates parameters in the direction opposite to the gradient (direction of steepest increase), thereby decreasing the objective function.

We want \(\boldsymbol{\theta}^* = \arg\min_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta})\). Gradient descent iteratively steps toward the minimum:

\[\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}_t)\]

where \(\eta > 0\) is the learning rate — how large a step we take.

Learning rate (\(\eta\)): The step size in gradient descent. Controls the tradeoff between speed (large \(\eta\)) and stability (small \(\eta\)). Choosing it well is crucial.

The gradient \(\nabla_{\boldsymbol{\theta}} \mathcal{L}\) points in the direction of steepest increase in loss. We move in the opposite direction to decrease the loss.

TipPredict – Try – Explain: Learning Rate

Predict: If the learning rate is too large, what will the loss curve look like?

Try: Sketch three loss curves: too large, too small, and reasonable. Label what you would change in each case.

Explain: The learning rate is not just a speed knob. It controls whether the optimizer can settle into a useful minimum at all.
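After sketching your curves, you can compare them with a toy run. The example below is a minimal sketch on the one-dimensional loss \(\mathcal{L}(\theta) = \theta^2\); the three learning rates are illustrative values chosen to show each regime and are not recommendations for the emulator.

```python
import jax
import jax.numpy as jnp

def loss(theta):
    return theta ** 2            # minimum at theta = 0

grad_loss = jax.grad(loss)

def run_gd(eta, theta0=1.0, steps=50):
    """Plain gradient descent; returns the loss after each step."""
    theta = jnp.asarray(theta0)
    history = []
    for _ in range(steps):
        theta = theta - eta * grad_loss(theta)
        history.append(float(loss(theta)))
    return history

# Too large, too small, and reasonable for this particular loss
for eta in (1.1, 1e-3, 0.3):
    history = run_gd(eta)
    print(f"eta = {eta}: final loss = {history[-1]:.3e}")
```

For this loss each step multiplies \(\theta\) by \((1 - 2\eta)\), so \(\eta = 1.1\) overshoots and grows, \(\eta = 10^{-3}\) barely moves in 50 steps, and \(\eta = 0.3\) converges quickly.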

NoteVisual Companion: Gradient Descent

3Blue1Brown’s “Gradient descent, how neural networks learn” is a good visual bridge from this equation to training behavior. The digit-recognition example has many more inputs than our emulator, but the optimization story is the same: define a loss, compute which direction increases it, then step the parameters in the opposite direction.

For this course, keep one extra scientific habit in mind: decreasing training loss is not enough. You still need validation loss, held-out evaluation, and a baseline comparison before trusting the emulator.

Credit: 3Blue1Brown / Grant Sanderson
NoteConnection to Module 5: Optimization vs. Sampling

You’ve now seen three approaches to finding good parameters:

Approach | What It Finds | Module
Analytical | Exact optimum (when possible) | Linear regression
Sampling (MCMC) | Full posterior distribution | Module 5, Project 4
Optimization (GD) | Single point estimate (mode) | This reading

Gradient descent finds one answer — the parameter values at the (local) minimum. MCMC explores the full posterior — the entire distribution of plausible parameters.

For inference, we often want the full posterior. For training emulators, we usually just want a good point estimate. But we’ll recover uncertainty via ensembles (Section 1.10).

Backpropagation = Chain Rule = Autodiff

Backpropagation: The algorithm for computing gradients of the loss with respect to all network parameters. Mathematically, it’s the chain rule applied systematically through the computational graph.

Backpropagation computes gradients of the loss with respect to all network parameters. The key insight: it’s just the chain rule applied systematically.

Consider how the loss depends on first-layer weights:

\[\frac{\partial \mathcal{L}}{\partial \mathbf{W}_1} = \frac{\partial \mathcal{L}}{\partial \hat{\mathbf{y}}} \cdot \frac{\partial \hat{\mathbf{y}}}{\partial \mathbf{h}_2} \cdot \frac{\partial \mathbf{h}_2}{\partial \mathbf{h}_1} \cdot \frac{\partial \mathbf{h}_1}{\partial \mathbf{W}_1}\]

The chain rule propagates gradients backward from the loss through each layer to the parameters.

NoteVisual Companion: What Backpropagation Is Doing

3Blue1Brown’s “What is backpropagation really doing?” is the best place to build intuition before staring at code. Watch for the central idea: each weight and bias gets a sensitivity score telling us how changing it would change the loss.

In your emulator, those sensitivities answer a concrete question: how should the weights of the MLP change so that predictions of \(f_{\rm bound}\), \(\sigma_v\), and \(r_h\) better match the simulator outputs?

Credit: 3Blue1Brown / Grant Sanderson
TipConnection to the JAX Module: You Already Know This!

In the JAX module, you learned that JAX computes gradients via automatic differentiation — specifically, reverse-mode autodiff. Backpropagation is exactly this: reverse-mode autodiff applied to neural networks.

When you write:

def loss_fn(params):
    pred = model(params, x)
    return jnp.mean((pred - y_true)**2)

gradients = jax.grad(loss_fn)(params)

JAX is performing backpropagation automatically. It traces the forward pass to build the computational graph, then applies the chain rule in reverse during grad().

You don’t implement backprop by hand. JAX does it for you. But understanding what it computes helps you debug and reason about training dynamics.
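To see that `jax.grad` really is the chain rule, you can write the backward pass by hand for a network small enough to differentiate on paper. The sketch below assumes a one-input, one-hidden-neuron, one-output network with arbitrary illustrative parameter values.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y_true):
    w1, b1, w2, b2 = params
    h = jnp.maximum(0.0, w1 * x + b1)    # hidden neuron with ReLU
    y_hat = w2 * h + b2                  # linear output
    return (y_hat - y_true) ** 2

params = (jnp.array(0.8), jnp.array(0.1), jnp.array(-1.5), jnp.array(0.3))
x, y_true = jnp.array(2.0), jnp.array(1.0)

# Manual forward pass, keeping intermediate values
w1, b1, w2, b2 = params
z = w1 * x + b1
h = jnp.maximum(0.0, z)
y_hat = w2 * h + b2

# Manual backward pass: chain rule from the loss back to each parameter
dL_dyhat = 2.0 * (y_hat - y_true)
dL_dw2 = dL_dyhat * h
dL_db2 = dL_dyhat
dL_dh = dL_dyhat * w2
dL_dz = dL_dh * (z > 0)                  # ReLU passes gradient only where z > 0
dL_dw1 = dL_dz * x
dL_db1 = dL_dz
manual = (dL_dw1, dL_db1, dL_dw2, dL_db2)

auto = jax.grad(loss_fn)(params, x, y_true)
for m, a in zip(manual, auto):
    print(float(m), float(a))
```

Each hand-derived value matches the corresponding entry from `jax.grad`, which is the sense in which backpropagation and reverse-mode autodiff are the same computation.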

NoteOptional Depth: Backpropagation Calculus and JAX

3Blue1Brown’s “Backpropagation calculus” shows the chain-rule bookkeeping more explicitly. This is optional depth, but it is exactly the math that JAX automates when you call jax.grad or eqx.filter_value_and_grad.

The important connection is this: JAX makes backpropagation practical because it records the computational graph of array operations and applies reverse-mode automatic differentiation to that graph. You write the forward model and scalar loss; JAX constructs the derivative computation needed by gradient descent.

Credit: 3Blue1Brown / Grant Sanderson

The Adam Optimizer

Adam (Adaptive Moment Estimation): An optimizer that adapts learning rates for each parameter based on gradient history. Tracks running averages of gradients (first moment) and squared gradients (second moment).

Plain gradient descent uses a fixed learning rate for all parameters. This can be suboptimal — some parameters benefit from larger steps, others from smaller.

Adam adapts the learning rate per parameter based on gradient history:

  • First moment (exponential moving average of gradients): Which direction have we been moving?
  • Second moment (exponential moving average of squared gradients): How variable have gradients been?

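Parameters whose gradients point consistently in the same direction take steps close to the full learning rate, largely independent of the raw gradient magnitude. Parameters with noisy, sign-flipping gradients take smaller effective steps.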

In practice: Adam is the default optimizer for neural networks. It works well across problems with minimal tuning. Start with learning rate \(\eta = 10^{-3}\).

Optax provides Adam and many other optimizers.
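To make the two moments concrete, here is a minimal sketch of a single Adam update written out by hand. It assumes the commonly used defaults \(\beta_1 = 0.9\), \(\beta_2 = 0.999\), and \(\epsilon = 10^{-8}\); the function name `adam_step` is hypothetical, and in your project you would use the Optax optimizer rather than this code.

```python
import jax.numpy as jnp

def adam_step(theta, grad, m, v, t, eta=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step count."""
    m = beta1 * m + (1.0 - beta1) * grad          # first moment (mean of gradients)
    v = beta2 * v + (1.0 - beta2) * grad ** 2     # second moment (mean of squares)
    m_hat = m / (1.0 - beta1 ** t)                # bias correction for zero init
    v_hat = v / (1.0 - beta2 ** t)
    theta = theta - eta * m_hat / (jnp.sqrt(v_hat) + eps)
    return theta, m, v

theta = jnp.array([1.0, 1.0])
grad = jnp.array([1000.0, 0.001])    # gradients differing by a factor of 1e6
m, v = jnp.zeros(2), jnp.zeros(2)

new_theta, m, v = adam_step(theta, grad, m, v, t=1)
step = theta - new_theta
print(step)
```

Both parameters move by roughly \(\eta\) on the first step even though their gradients differ by six orders of magnitude. That per-parameter rescaling is what "adaptive" means here.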


1.8: Weight Initialization — Breaking Symmetry

Priority: Important

Why Initialization Matters

Before training begins, you must set initial values for all weights and biases. This choice profoundly affects training dynamics.

WarningThe Symmetry Problem

What if you initialize all weights to zero?

Every neuron in a layer computes the same function (weighted sum of inputs with weights 0 \(\to\) output 0). They all produce identical outputs. During backprop, they all receive identical gradients. They all update identically. They stay identical forever.

You’ve effectively reduced your 64-neuron layer to a single neuron. The network can’t learn complex features because all neurons are copies of each other.

Random initialization breaks this symmetry. Different neurons start with different weights, compute different functions, receive different gradients, and specialize to detect different features.
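The symmetry argument can be checked directly. The sketch below assumes a tiny network with 4 hidden neurons and uses tanh so that no unit is switched off; the argument is the same for ReLU. It compares the first-layer gradient under a constant initialization with the gradient under a random one.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y_true):
    W1, b1, w2, b2 = params
    h = jnp.tanh(W1 @ x + b1)          # 4 hidden neurons
    y_hat = w2 @ h + b2
    return (y_hat - y_true) ** 2

x, y_true = jnp.array([0.5, 1.0]), jnp.array(1.0)

# Constant initialization: every hidden neuron starts identical
const_params = (jnp.full((4, 2), 0.5), jnp.zeros(4),
                jnp.full(4, 0.5), jnp.array(0.0))
g_const = jax.grad(loss_fn)(const_params, x, y_true)[0]

# Random initialization: neurons start different
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
rand_params = (jax.random.normal(k1, (4, 2)), jnp.zeros(4),
               jax.random.normal(k2, (4,)), jnp.array(0.0))
g_rand = jax.grad(loss_fn)(rand_params, x, y_true)[0]

print(g_const)   # every row (one per neuron) is the same
print(g_rand)    # rows differ, so neurons can specialize
```

With identical rows in the gradient, every neuron receives the same update and the layer never stops behaving like a single neuron.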

Scale Matters Too

Even with random initialization, the scale of initial weights affects training:

Weights too large: Pre-activations \(z = \mathbf{W}\mathbf{x} + \mathbf{b}\) become large. For sigmoid/tanh, this saturates the activation (gradient \(\approx\) 0). For ReLU, values may explode through layers.

Weights too small: Signals shrink as they pass through layers. By the time they reach the output, everything is near zero. Gradients vanish.

Xavier/Glorot initialization: Initialize weights with variance \(\text{Var}(w) = \frac{2}{n_{\text{in}} + n_{\text{out}}}\) where \(n_{\text{in}}\), \(n_{\text{out}}\) are input/output dimensions of the layer. Keeps signal magnitudes stable across layers.

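Xavier/Glorot initialization sets a suitable scale. (Equinox's eqx.nn.Linear uses a closely related default: weights drawn uniformly from \([-1/\sqrt{n_{\text{in}}}, 1/\sqrt{n_{\text{in}}}]\), which also shrinks with layer width.) The Glorot normal form is: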

\[w_{ij} \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}} + n_{\text{out}}}\right)\]

This keeps the variance of activations roughly constant across layers, enabling stable gradient flow.
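A quick numerical check of the scale argument, as a sketch: the helper `glorot_normal` below is written by hand from the variance formula above (it is not a library call), and the layers are purely linear so that only the weight scale is being tested.

```python
import jax
import jax.numpy as jnp

def glorot_normal(key, n_in, n_out):
    """Draw an (n_out, n_in) weight matrix with variance 2 / (n_in + n_out)."""
    std = jnp.sqrt(2.0 / (n_in + n_out))
    return std * jax.random.normal(key, (n_out, n_in))

key = jax.random.PRNGKey(0)
k_x, *k_layers = jax.random.split(key, 6)
width = 256
h = jax.random.normal(k_x, (width,))     # unit-variance input signal

# Five linear layers in a row; watch the signal variance
for i, k in enumerate(k_layers):
    h = glorot_normal(k, width, width) @ h
    print(f"layer {i + 1}: activation variance = {float(jnp.var(h)):.3f}")
```

The variance stays of order one from layer to layer. Rerunning with the standard deviation multiplied by 3 or by 1/3 shows the explosion and shrinkage described above.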

Why Different Seeds \(\to\) Different Solutions

Neural network loss landscapes are non-convex — they have many local minima and saddle points. Gradient descent finds a minimum, not the global minimum. Which minimum you find depends on where you start.

NoteNon-Convex Loss Landscape

Imagine two starting points on a hilly loss surface. Both can roll downhill, both can reach low loss, and they can still end in different valleys. Neural-network training is like this: different random initializations can produce different but similarly accurate functions.

Different random seeds produce different initial weights \(\to\) different optimization trajectories \(\to\) different final solutions.

These solutions may all have similar training loss but differ in their predictions, especially in regions with sparse data. This is the foundation for ensemble uncertainty (Section 1.10): train multiple networks with different seeds and measure how much they disagree.


1.9: Practical Training

Priority: Essential

Data Normalization

Standardization (z-score normalization): Transform each feature to zero mean and unit variance: \(\tilde{x} = (x - \mu)/\sigma\). Essential for neural network training.

Neural networks are sensitive to input/output scales. If \(Q_0 \in [0.5, 1.5]\) but \(a \in [50, 200]\), gradients with respect to \(a\)-related weights will dominate, causing slow and unstable learning.

Standardization transforms each feature to zero mean and unit variance:

\[\tilde{x} = \frac{x - \mu_x}{\sigma_x}\]

where \(\mu_x\) and \(\sigma_x\) are computed from your training set only.

WarningThe Cardinal Rule of Normalization

Compute \(\mu\) and \(\sigma\) from the training set only. Apply the same transformation (using training statistics) to:

  • Training inputs and outputs
  • Validation/test inputs and outputs
  • Any new data at inference time

Never use test set statistics for normalization — that would be data leakage, using information from data you’re supposed to predict.

Denormalization: To convert predictions back to physical units: \[x = \tilde{x} \cdot \sigma_x + \mu_x\]

Multi-Output Normalization

Multi-output regression: Predicting multiple target variables simultaneously. Each output should typically be normalized independently.

Your emulator predicts three quantities: \((f_{\text{bound}}, \sigma_v, r_h)\). The MSE loss sums over all outputs:

\[\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \left[ (\hat{f}_{\text{bound},i} - f_{\text{bound},i})^2 + (\hat{\sigma}_{v,i} - \sigma_{v,i})^2 + (\hat{r}_{h,i} - r_{h,i})^2 \right]\]

Problem: If outputs have different scales, this implicitly weights them unequally. If \(r_h\) varies from 20–150 AU (range 130) while \(f_{\text{bound}}\) varies from 0.3–1.0 (range 0.7), squared errors for \(r_h\) are ~35,000\(\times\) larger, since \((130/0.7)^2 \approx 3.4 \times 10^4\). The network will focus on \(r_h\) at the expense of \(f_{\text{bound}}\).

Solution: Normalize each output independently to zero mean, unit variance. Then all three contribute roughly equally to the loss.

The normalization pattern (you’ll implement this):

NORMALIZATION PROCEDURE
=======================

1. COMPUTE STATISTICS (from training set only!)
   mu = mean of training data, per feature       # Shape: (3,) for outputs
   sigma = std of training data, per feature     # Shape: (3,) for outputs

2. NORMALIZE any data (train, test, or new):
   x_tilde = (x - mu) / sigma

3. DENORMALIZE predictions back to physical units:
   x = x_tilde * sigma + mu

Critical: Compute \(\mu\) and \(\sigma\) from training data, then use those same values everywhere — for normalizing test data, for denormalizing predictions, and inside your NumPyro model. Store these statistics alongside your trained model.
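The procedure above can be sketched in a few lines. The data here are random stand-ins with three columns of very different scale, and the function names `fit_stats`, `normalize`, and `denormalize` are illustrative; your own implementation may organize this differently.

```python
import jax
import jax.numpy as jnp

def fit_stats(train):
    """Per-feature mean and standard deviation from training data only."""
    return jnp.mean(train, axis=0), jnp.std(train, axis=0)

def normalize(x, mu, sigma):
    return (x - mu) / sigma

def denormalize(x_tilde, mu, sigma):
    return x_tilde * sigma + mu

# Stand-in outputs with very different scales (3 columns, made-up numbers)
key_train, key_test = jax.random.split(jax.random.PRNGKey(0))
scales = jnp.array([0.2, 1.0, 40.0])
offsets = jnp.array([0.6, 3.0, 85.0])
y_train = offsets + scales * jax.random.normal(key_train, (70, 3))
y_test = offsets + scales * jax.random.normal(key_test, (15, 3))

mu_y, sigma_y = fit_stats(y_train)                 # shape (3,) each
y_train_norm = normalize(y_train, mu_y, sigma_y)
y_test_norm = normalize(y_test, mu_y, sigma_y)     # same training statistics
y_test_back = denormalize(y_test_norm, mu_y, sigma_y)

print(jnp.mean(y_train_norm, axis=0), jnp.std(y_train_norm, axis=0))
```

Note that the normalized test set will not have exactly zero mean and unit variance. That is expected, because its statistics were never used.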

WarningNormalization and Likelihood Consistency

Training happens most cleanly in normalized units. Scientific interpretation usually happens in physical units. Do not mix them.

If the network predicts normalized outputs, then the inference pipeline must either:

  1. denormalize predictions before comparing to physical observations, or
  2. transform the observation vector and covariance into the same normalized units.

For most final projects, the safest pattern is:

\[ \tilde{\mathbf{x}} = \frac{\mathbf{x} - \boldsymbol{\mu}_x}{\boldsymbol{\sigma}_x}, \qquad \tilde{\mathbf{y}}_{\rm pred} = f_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}), \qquad \mathbf{y}_{\rm pred} = \tilde{\mathbf{y}}_{\rm pred} \odot \boldsymbol{\sigma}_y + \boldsymbol{\mu}_y. \]

Then evaluate the likelihood using \(\mathbf{y}_{\rm pred}\), \(\mathbf{y}_{\rm obs}\), and uncertainties in physical units. Mixing normalized predictions with physical uncertainties is wrong, even if the code runs.
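A minimal sketch of that pattern follows. Everything numeric here is made up for illustration: a fixed linear map stands in for the trained network, and the normalization statistics and observational uncertainties are hypothetical values. The function names are also illustrative.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def predict_physical(model, x, x_stats, y_stats):
    """Normalize inputs, apply the network, denormalize outputs."""
    (mu_x, sigma_x), (mu_y, sigma_y) = x_stats, y_stats
    y_tilde = model((x - mu_x) / sigma_x)
    return y_tilde * sigma_y + mu_y

def log_likelihood(y_pred, y_obs, sigma_obs):
    """Independent Gaussian log-likelihood, everything in physical units."""
    return jnp.sum(norm.logpdf(y_obs, loc=y_pred, scale=sigma_obs))

# Hypothetical stand-in for a trained network: a fixed linear map
A = jnp.array([[0.5, -0.2], [0.1, 0.3], [-0.4, 0.6]])
model = lambda x_tilde: A @ x_tilde

# Hypothetical training-set statistics for inputs and outputs
x_stats = (jnp.array([1.0, 125.0]), jnp.array([0.3, 40.0]))
y_stats = (jnp.array([0.6, 3.0, 85.0]), jnp.array([0.2, 1.0, 40.0]))

x = jnp.array([1.2, 100.0])                  # physical (Q_0, a)
y_pred = predict_physical(model, x, x_stats, y_stats)
sigma_obs = jnp.array([0.05, 0.2, 5.0])      # physical-unit uncertainties

y_obs = y_pred + sigma_obs                   # observation one sigma away
print(log_likelihood(y_pred, y_obs, sigma_obs))
```

The key design choice is that denormalization happens inside `predict_physical`, so nothing downstream ever sees normalized outputs.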

TipPredict – Try – Explain: Normalization Leakage

Predict: What changes if you compute the mean and standard deviation using all data before the split?

Try: Write down which future information leaks into training when test examples help define the normalization.

Explain: Leakage can make an emulator look more accurate than it really is. Training statistics are part of the trained model; validation and test data must be transformed with those fixed values.

Hyperparameters

Hyperparameter: A setting that controls training but isn’t learned from data. Must be chosen before training begins.

Epoch: One complete pass through all training data. If you have 100 training examples and train for 500 epochs, the network sees each example 500 times.

Key hyperparameters for your emulator:

Hyperparameter | Starting Value | Adjustment Guidance
Learning rate | \(10^{-3}\) | Decrease by 10\(\times\) if loss oscillates; increase if loss decreases too slowly
Epochs | 500–1000 | Until loss plateaus on validation set
Hidden layers | 2 | More layers rarely help for smooth, low-dimensional functions
Neurons per layer | 64 | Try 32 or 128 if 64 seems too small/large
Batch size | Full batch for modest datasets | Mini-batches matter only once the dataset is large enough that full-batch updates become slow

Train/Validation/Test Split

With limited simulations, protect three roles:

  • Training set: Fit model parameters
  • Validation/calibration set: Choose hyperparameters, stop training, and estimate likelihood widths
  • Test set: Final held-out report after decisions are fixed

If your dataset is very small, a practical starting point is roughly 70/15/15. If that leaves too few test cases, use repeated random splits or cross-validation for development, but still keep a small final test set untouched for reporting.
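A reproducible split can be built from a single random permutation of the example indices. The sketch below assumes 100 simulations and the 70/15/15 starting point; the function name `split_indices` is illustrative.

```python
import jax
import jax.numpy as jnp

def split_indices(key, n, frac_train=0.70, frac_val=0.15):
    """Shuffle indices once, then cut into train / validation / test."""
    perm = jax.random.permutation(key, n)
    n_train = int(round(frac_train * n))
    n_val = int(round(frac_val * n))
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]

idx_train, idx_val, idx_test = split_indices(jax.random.PRNGKey(0), 100)
print(len(idx_train), len(idx_val), len(idx_test))   # 70 15 15

# Usage with arrays X of shape (n, 2) and Y of shape (n, 3):
#   X_train, Y_train = X[idx_train], Y[idx_train]
```

Fixing the key fixes the split, so the test set stays the same across every experiment you run.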

WarningThe Test Set Is Sacred

Never use test data for any training decisions:

  • Don’t tune hyperparameters based on test performance
  • Don’t normalize using test statistics
  • Don’t stop training based on test loss

The test set exists solely to evaluate your final model. If you peek at it during development, you lose your ability to honestly assess generalization.

Convergence Criteria

Training is “done” when:

  1. Loss has plateaued: Not decreasing meaningfully over ~100 epochs
  2. Loss is reasonably small: MSE \(\lesssim 0.01\) on normalized data corresponds to RMSE \(\lesssim 0.1\), so typical errors are at most about 10% of one standard deviation
  3. No pathologies: No NaN values, no loss increasing

The Training Loop

Training loop: The iterative process of computing predictions, calculating loss, computing gradients, and updating parameters. Repeats for many epochs until convergence.

A typical training loop:

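import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax

# Sketch: assumes model, params, x_train, y_train, and num_epochs are already defined
def loss_fn(params):
    predictions = model(params, x_train)            # forward pass
    return jnp.mean((predictions - y_train) ** 2)   # MSE loss

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

losses = []
for epoch in range(num_epochs):
    # Forward and backward pass: loss value and gradients from the same function
    loss, gradients = jax.value_and_grad(loss_fn)(params)
    losses.append(float(loss))

    # Update parameters with Adam
    updates, opt_state = optimizer.update(gradients, opt_state, params)
    params = optax.apply_updates(params, updates)

    if epoch % 100 == 0:
        print(f"Epoch {epoch}: loss = {losses[-1]:.6f}")

# ALWAYS plot the loss curve
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.yscale("log")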

Always plot the loss curve. Its shape reveals problems that printed numbers hide.

NoteLoss Curve Diagnostics
Curve shape | Likely diagnosis | First response
Smooth decrease, then plateau | Healthy training | Stop when validation loss also plateaus
Large oscillations or divergence | Learning rate too high | Decrease learning rate by 10\(\times\)
Very slow decrease | Learning rate too low | Increase learning rate or train longer
Training loss falls while validation loss rises | Overfitting | Early stop, simplify, or add data

Debugging Checklist

ImportantWhen Training Goes Wrong

Before training — verify setup:

During training — monitor health:

After training — validate results:

Common fixes:

Symptom | Likely Cause | Fix
Loss stays constant | Learning rate too small | Increase by 10\(\times\)
Loss oscillates wildly | Learning rate too large | Decrease by 10\(\times\)
Loss becomes NaN | Numerical instability | Lower LR, check for division by zero
Loss plateaus high | Architecture too simple, or bug | Add neurons/layers, check code
Train loss good, validation loss bad | Overfitting | Early stopping, more data, or accept it

1.10: Uncertainty via Ensembles

Priority: Essential

The Problem with Single Networks

A trained neural network gives point predictions — single numbers with no uncertainty. But scientific applications need uncertainty quantification:

  • How confident is this prediction?
  • Are we extrapolating outside the training distribution?
  • How should prediction uncertainty propagate to inference?

A single network provides none of this. Worse, networks can be confidently wrong, especially when extrapolating beyond training data. An ensemble helps, but ensemble spread is not automatically the total uncertainty needed for inference.

Deep Ensembles

Ensemble: A collection of models that make independent predictions. Disagreement among members quantifies uncertainty. For neural networks, different random initializations suffice to create an ensemble.

The simplest approach to neural network uncertainty: train multiple networks with different random seeds.

Why does this work? From Section 1.8: different initializations lead to different local minima. These solutions may give similar predictions where training data is dense but diverge where data is sparse.

Procedure:

  1. Train \(M\) networks with different random seeds (typically \(M = 3\)–5)
  2. For any input \(\mathbf{x}\):
    • Get predictions from each: \(\hat{y}_1(\mathbf{x}), \ldots, \hat{y}_M(\mathbf{x})\)
    • Mean prediction: \(\bar{y}(\mathbf{x}) = \frac{1}{M} \sum_{m=1}^{M} \hat{y}_m(\mathbf{x})\)
    • Uncertainty: \(\sigma_y(\mathbf{x}) = \sqrt{\frac{1}{M-1} \sum_{m=1}^{M} (\hat{y}_m(\mathbf{x}) - \bar{y}(\mathbf{x}))^2}\)

Epistemic uncertainty: Uncertainty due to limited knowledge — here, limited training data. Reducible with more data. Distinct from aleatoric uncertainty (irreducible noise in the process itself).

The spread \(\sigma_y\) estimates one part of epistemic uncertainty — uncertainty from limited training data and training trajectory. It is a warning light, not a complete error budget.
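The mean and spread formulas translate directly into array operations. The sketch below uses made-up predictions from \(M = 5\) members for a single input; it covers only the statistics, not how the members are trained.

```python
import jax.numpy as jnp

# Stand-in predictions from M = 5 members for one input, 3 outputs each
preds = jnp.array([
    [0.71, 2.9, 84.0],
    [0.69, 3.1, 86.5],
    [0.70, 3.0, 85.2],
    [0.72, 3.0, 83.8],
    [0.68, 3.0, 85.5],
])                                         # shape (M, num_outputs)

mean = jnp.mean(preds, axis=0)
std = jnp.std(preds, axis=0, ddof=1)       # 1/(M-1) normalization, as in the formula

# The same spread written out term by term
M = preds.shape[0]
std_manual = jnp.sqrt(jnp.sum((preds - mean) ** 2, axis=0) / (M - 1))

print(mean)
print(std, std_manual)
```

With only five members the spread itself is a noisy estimate, which is one more reason to treat it as a warning light and check it against held-out residuals.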

TipPredict – Try – Explain: Ensemble Uncertainty

Predict: Where should an ensemble disagree most: near dense training points, near domain edges, or far outside the sampled region?

Try: For one output such as \(f_{\rm bound}\), plot a one-dimensional slice in \(Q_0\) at fixed \(a\): training points, each ensemble member, the ensemble mean, and a shaded \(\pm 2\sigma\) band.

Explain: Ensemble spread is a warning light for epistemic uncertainty. It is most useful when you compare it to where the training data actually live.

Interpreting Ensemble Spread

Uncertainty should be high:

  • Near edges of training distribution (extrapolation begins)
  • In gaps between training points
  • Where the function is complex/rapidly varying

Uncertainty should be low:

  • In regions densely covered by training data
  • Where all ensemble members agree
  • Where the function is simple/slowly varying
NoteConnection to Module 5: Epistemic vs. Aleatoric

Remember from Module 5:

Epistemic uncertainty is reducible with more data. It represents “we’re uncertain because we haven’t seen enough examples.” Ensemble spread captures this — with more training data, ensemble members would agree better.

Aleatoric uncertainty is irreducible noise in the process itself. For your emulator, this can include scatter from different random cluster realizations at fixed macro-parameters \((Q_0, a)\). The exact same initial conditions and deterministic code path should reproduce the same result; the scatter comes from changing the realization, not from rerunning an identical state.

Ensembles capture epistemic uncertainty but not aleatoric uncertainty. The ensemble members all learned from the same training data, so they share the same aleatoric floor. More sophisticated approaches (predicting variance directly, Bayesian neural networks) can capture both, but ensembles are a practical first step.

WarningEnsemble Spread Is Not Total Uncertainty

For inference, a more honest uncertainty budget is

\[ \Sigma_{\rm total} = \Sigma_{\rm obs} + \Sigma_{\rm emu} + \Sigma_{\rm sim}. \]

Here \(\Sigma_{\rm obs}\) represents observational or target-measurement uncertainty, \(\Sigma_{\rm emu}\) represents emulator approximation uncertainty, and \(\Sigma_{\rm sim}\) represents simulator stochasticity or numerical uncertainty. Ensemble spread can help estimate part of \(\Sigma_{\rm emu}\), but it does not automatically include the other terms. It must be checked against held-out residuals.
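For diagonal covariances, the budget above reduces to adding standard deviations in quadrature. The sketch below assumes the three sources are independent and uses hypothetical per-output values; none of these numbers come from a real emulator.

```python
import jax.numpy as jnp

# Hypothetical per-output standard deviations (f_bound, sigma_v, r_h), physical units
sigma_obs = jnp.array([0.03, 0.20, 4.0])
sigma_emu = jnp.array([0.02, 0.10, 3.0])
sigma_sim = jnp.array([0.01, 0.05, 2.0])

# Diagonal covariance matrices add term by term
Sigma_total = (jnp.diag(sigma_obs ** 2)
               + jnp.diag(sigma_emu ** 2)
               + jnp.diag(sigma_sim ** 2))

# Equivalent per-output width for a diagonal Gaussian likelihood
sigma_total = jnp.sqrt(jnp.diag(Sigma_total))
print(sigma_total)
```

The total width is always at least as large as the observational width alone, so dropping the emulator and simulator terms can only make the likelihood narrower than it should be.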

TipPredict – Try – Explain: Overconfident Inference

Predict: If your emulator is accurate in the center of parameter space but poor near the edges, what happens to the posterior over \((Q_0, a)\) if you use an unrealistically tiny uncertainty?

Try: Sketch a likelihood surface with a narrow peak in the wrong location.

Explain: An overconfident emulator can produce an overconfident posterior. Calibration is what prevents a fast approximation from becoming a fast wrong answer.

Ensemble in Practice

The ensemble workflow is conceptually simple:

ENSEMBLE WORKFLOW
=================

Training:
1. Split random key into M independent keys
2. Train M models, each with a different key
3. Store all trained models

Prediction:
1. Get predictions from all M models
2. Stack predictions into array of shape (M, num_outputs)
3. Mean = jnp.mean(predictions, axis=0)
4. Std = jnp.std(predictions, axis=0, ddof=1)  # ddof=1 for unbiased

Implementation hints:

  • jax.random.split(key, M) creates M independent keys
  • Store models in a Python list
  • For single input: preds = jnp.array([model(x) for model in models])
  • For batched inputs: use jax.vmap inside the list comprehension

For your final project, \(M = 5\) ensemble members is a good balance between computational cost and uncertainty quality. You’ll implement this yourself — the pattern above tells you what to compute, not how to code it.


Conceptual Checkpoint: Training and Uncertainty

Before moving to implementation, verify your understanding:

ImportantSelf-Assessment
  1. Loss interpretation: If your MSE loss on normalized data is 0.04, roughly how large are typical prediction errors relative to output standard deviations?

  2. Learning rate: Describe what happens to the loss curve if the learning rate is (a) too large, (b) too small, (c) well-chosen.

  3. Normalization: Why must you use training set statistics (not test set) when normalizing test data?

  4. Multi-output: Your emulator predicts three outputs. Before normalization, \(r_h\) has 100\(\times\) larger variance than \(f_{\rm bound}\). What problem does this cause, and how do you fix it?

  5. Ensembles: You train 5 networks on the same data but with different random seeds. Why do they converge to different solutions?

  6. Uncertainty interpretation: Your ensemble predicts \(r_h = 85 \pm 3\) AU for one input and \(r_h = 120 \pm 20\) AU for another. What does the larger uncertainty in the second case tell you?

  1. MSE = 0.04 means RMSE = 0.2. On normalized data (std = 1), typical errors are ~20% of one standard deviation, which is quite good.

  2. (a) Too large: loss oscillates wildly, may diverge. (b) Too small: loss decreases very slowly, may not reach good minimum in reasonable time. (c) Well-chosen: smooth decrease, then plateau.

  3. Using test statistics would be data leakage: incorporating information from data you’re supposed to predict. The model must work with only training-time information.

  4. Without normalization, the loss is dominated by \(r_h\) (larger squared errors). The network focuses on \(r_h\) at the expense of \(f_{\rm bound}\). Fix: normalize each output independently to unit variance.

  5. The loss landscape has many local minima. Different initial weights lead to different optimization trajectories, ending at different minima with similar loss but different predictions.

  6. The second prediction has higher epistemic uncertainty. The ensemble members disagree more, likely because that input region had sparser training data. You should be less confident in the second prediction.


1.11: The JAX Ecosystem — Equinox and Optax

Priority: Essential

Equinox: Neural Networks as PyTrees

Equinox: A JAX library for neural networks by Patrick Kidger. Models are PyTrees (nested containers JAX understands), enabling seamless use of jax.grad, jax.jit, jax.vmap.

Equinox makes neural networks fit naturally into JAX’s functional paradigm:

Models are PyTrees: Parameters live in a tree structure. JAX transformations (grad, jit, vmap) work seamlessly because they know how to traverse PyTrees.

Models are callables: An Equinox model is a function. You call it: output = model(input).

Explicit state: No hidden mutable variables. All parameters are explicit. This makes debugging easier and enables JAX’s transformations.

Learning from Documentation

ImportantYour Primary Resources

The final project requires you to learn Equinox and Optax from their documentation. This is intentional — reading documentation and adapting examples is a core professional skill.

Equinox (https://docs.kidger.site/equinox/):

  • “Getting Started”: How to define an eqx.Module with __init__ and __call__
  • “Train a Neural Network”: Complete training loop pattern you’ll adapt

Optax (https://optax.readthedocs.io/en/latest/):

  • “Quick Start”: How to create optimizers and apply updates

Your task: Work through these tutorials before starting your emulator. Run the examples. Modify them. Break them and fix them. Then adapt the patterns to your problem.

Glass-Box Exercise: Manual Forward Pass

Before using any library, make sure you understand what’s happening inside:

Implement a forward pass manually using only JAX primitives:

import jax
import jax.numpy as jnp

def manual_forward(params, x):
    """Forward pass through a 2-64-64-3 MLP.

    params: dict with keys 'W1', 'b1', 'W2', 'b2', 'W3', 'b3'
    x: input array of shape (2,)
    Returns: output array of shape (3,)
    """
    # Layer 1: linear transformation + ReLU
    z1 = params['W1'] @ x + params['b1']
    h1 = jnp.maximum(0, z1)

    # Layer 2: linear transformation + ReLU
    z2 = params['W2'] @ h1 + params['b2']
    h2 = jnp.maximum(0, z2)

    # Output layer: linear only (no activation)
    y = params['W3'] @ h2 + params['b3']
    return y

Verify that JAX can differentiate this:

def loss_fn(params, x, y_true):
    y_pred = manual_forward(params, x)
    return jnp.mean((y_pred - y_true)**2)

grad_fn = jax.grad(loss_fn)  # This works!

This exercise shows that neural networks are just compositions of simple operations. Equinox handles the bookkeeping — but you should understand what it’s managing.
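To actually run the exercise you need a parameter dictionary. The sketch below is one possible setup: the helper `init_params` is hypothetical, it uses Glorot-scaled normal weights with zero biases as an assumption, and the forward pass is repeated in compact form so the block runs on its own. The input and target values are arbitrary.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(2, 64, 64, 3)):
    """Glorot-scaled normal weights and zero biases, stored in a dict."""
    params = {}
    keys = jax.random.split(key, len(sizes) - 1)
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:]), start=1):
        std = jnp.sqrt(2.0 / (n_in + n_out))
        params[f"W{i}"] = std * jax.random.normal(k, (n_out, n_in))
        params[f"b{i}"] = jnp.zeros(n_out)
    return params

def manual_forward(params, x):
    h1 = jnp.maximum(0, params["W1"] @ x + params["b1"])
    h2 = jnp.maximum(0, params["W2"] @ h1 + params["b2"])
    return params["W3"] @ h2 + params["b3"]

def loss_fn(params, x, y_true):
    return jnp.mean((manual_forward(params, x) - y_true) ** 2)

params = init_params(jax.random.PRNGKey(0))
x = jnp.array([0.5, 1.0])                 # arbitrary normalized input
y_true = jnp.array([0.1, -0.3, 0.5])      # arbitrary normalized target

# Gradients come back as a dict with the same keys and shapes as params
grads = jax.grad(loss_fn)(params, x, y_true)
for name in sorted(grads):
    print(name, grads[name].shape)
```

`jax.grad` differentiates with respect to the first argument by default, which is why `params` comes first in `loss_fn`.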

A Teaching Example: Learning \(\sin(x)\)

To illustrate Equinox patterns without solving your final project, let’s build a tiny network that learns the sine function. This is NOT your emulator architecture — adapt the concepts, don’t copy the code.

import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# A SIMPLE 1D example - your emulator will differ!
class TinySinNet(eqx.Module):
    layer1: eqx.nn.Linear
    layer2: eqx.nn.Linear

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.layer1 = eqx.nn.Linear(1, 16, key=k1)  # 1 input
        self.layer2 = eqx.nn.Linear(16, 1, key=k2)  # 1 output

    def __call__(self, x):
        x = jax.nn.relu(self.layer1(x))
        return self.layer2(x)
WarningHow Your Emulator Differs

TinySinNet has two linear layers (1 \(\to\) 16 \(\to\) 1), which gives it a single hidden layer.

Your emulator needs three linear layers (2 \(\to\) 64 \(\to\) 64 \(\to\) 3):

  • 2 inputs, not 1
  • Two hidden layers of 64 neurons each, not one layer of 16
  • 3 outputs, not 1
  • ReLU activation after each hidden layer

Adapt the pattern (how to define eqx.Module, split keys, write __call__), not the specific architecture.

Worked Mini-Example: A Tiny End-to-End Regressor

The example below is intentionally small. It shows the full workflow on a toy regression problem: normalize data, define an Equinox model, train it with Optax, and convert predictions back to physical units.

NoteCode Literacy: Tiny End-to-End Regressor

Purpose: Show the full Equinox/Optax training pattern on a toy problem before you adapt it to the emulator.

Inputs and outputs: The input is a one-dimensional array \(x\); the target is \(\sin(3x)\); the output is a trained model plus a physical-unit RMSE.

Common pitfall: Do not copy the architecture into the final project. Copy the workflow: train-only normalization, vmap for batches, filtered gradients, optimizer updates, denormalized evaluation.

import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# Toy dataset: x -> sin(3x)
key = jax.random.PRNGKey(0)
x = jnp.linspace(-1.0, 1.0, 64)[:, None]
y = jnp.sin(3.0 * x)

# Normalize using training statistics
x_mean, x_std = jnp.mean(x, axis=0), jnp.std(x, axis=0)
y_mean, y_std = jnp.mean(y, axis=0), jnp.std(y, axis=0)
x_norm = (x - x_mean) / x_std
y_norm = (y - y_mean) / y_std

class TinyRegressor(eqx.Module):
    l1: eqx.nn.Linear
    l2: eqx.nn.Linear

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.l1 = eqx.nn.Linear(1, 32, key=k1)
        self.l2 = eqx.nn.Linear(32, 1, key=k2)

    def __call__(self, x):
        x = jax.nn.relu(self.l1(x))
        return self.l2(x)

model = TinyRegressor(key)
optimizer = optax.adam(1e-2)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

def loss_fn(model, x_batch, y_batch):
    preds = jax.vmap(model)(x_batch)
    return jnp.mean((preds - y_batch) ** 2)

@eqx.filter_jit
def train_step(model, opt_state, x_batch, y_batch):
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x_batch, y_batch)
    updates, opt_state = optimizer.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

for epoch in range(500):
    model, opt_state, loss = train_step(model, opt_state, x_norm, y_norm)

pred_norm = jax.vmap(model)(x_norm)
pred = pred_norm * y_std + y_mean
rmse = jnp.sqrt(jnp.mean((pred - y) ** 2))
print(f"RMSE in physical units: {rmse.item():.4f}")

What to notice:

  • The model is just a callable function with explicit parameters.
  • Normalization happens outside the model and must be saved for later use.
  • jax.vmap(model) handles the batch dimension.
  • eqx.filter_value_and_grad and optax.adam give you the core training loop.
  • After training, you must denormalize predictions before interpreting them.

For your final project, the structure is the same. The differences are the scientific dataset, the emulator architecture, the held-out evaluation, and the uncertainty/inference layers you build on top.

Key patterns to notice (these transfer to your emulator):

  1. Inherit from eqx.Module — makes your class a PyTree
  2. Declare layers as typed attributes — Equinox tracks these
  3. Split keys for each layer — ensures independent initialization
  4. __call__ defines forward pass — apply layers and activations

Essential Equinox Patterns

These small patterns are the building blocks. You’ll combine them for your emulator.

Pattern 1: Creating layers

# eqx.nn.Linear(in_features, out_features, key=...)
layer = eqx.nn.Linear(64, 32, key=some_key)

Pattern 2: Applying activations

# Activations are just JAX functions
x = jax.nn.relu(layer(x))  # ReLU
x = jax.nn.tanh(layer(x))  # Tanh (if you needed it)

Pattern 3: Filtering arrays from models

# Equinox models contain arrays (parameters) and non-arrays (structure)
# Many operations need just the arrays:
params = eqx.filter(model, eqx.is_array)

Pattern 4: Computing gradients

# eqx.filter_value_and_grad works like jax.value_and_grad
# but handles models with non-array components
loss, grads = eqx.filter_value_and_grad(loss_fn)(model)

Pattern 5: Applying updates

# After computing updates from optimizer:
new_model = eqx.apply_updates(model, updates)

Pattern 6: JIT-compiling with models

# Use eqx.filter_jit instead of jax.jit for functions involving eqx.Module
@eqx.filter_jit
def train_step(model, opt_state, x_batch, y_batch):
    ...

The Training Loop Structure

Here’s the structure of a training loop — you’ll fill in the details:

TRAINING LOOP PSEUDOCODE
========================

1. INITIALIZE
   - Create model with random key
   - Create optimizer (e.g., optax.adam)
   - Initialize optimizer state from model parameters

2. FOR EACH EPOCH:
   a. FORWARD PASS
      - Compute predictions: pred = model(x_train)
      - (Use jax.vmap to handle batches)

   b. COMPUTE LOSS
      - MSE: loss = mean((pred - y_train)**2)

   c. BACKWARD PASS
      - Compute gradients w.r.t. model parameters
      - (Use eqx.filter_value_and_grad)

   d. UPDATE
      - Get updates from optimizer
      - Apply updates to model

   e. LOG
      - Store loss for plotting
      - Print progress periodically

3. PLOT LOSS CURVE (always!)

4. RETURN trained model

Your task: Translate this pseudocode into working Equinox/Optax code. The documentation tutorials show exactly how.

Training an Ensemble: The Idea

For uncertainty quantification, you’ll train multiple networks:

ENSEMBLE PSEUDOCODE
===================

1. Split your random key into M keys (one per model)

2. FOR EACH key:
   - Train a model using that key
   - (Different key -> different initialization -> different solution)
   - Store the trained model

3. TO PREDICT with uncertainty:
   - Get prediction from each model
   - Mean = average of predictions
   - Uncertainty = standard deviation of predictions

Implementation hints:

  • jax.random.split(key, n) gives you n independent keys
  • Store models in a Python list
  • For ensemble prediction, iterate over models and stack results with jnp.array([...])
  • Use jnp.mean(..., axis=0) and jnp.std(..., axis=0, ddof=1) to aggregate
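The aggregation step in these hints can be sketched without any trained networks. In the example below, `make_standin_model` is a hypothetical placeholder that returns a random linear map in place of a trained emulator, so only the stacking and the mean/standard-deviation logic carry over to your project.

```python
import jax
import jax.numpy as jnp


def make_standin_model(key):
    """Hypothetical stand-in for a trained network: a fixed random linear map."""
    w = 1.0 + 0.1 * jax.random.normal(key, (3, 2))
    return lambda x: w @ x


# One key per ensemble member.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
models = [make_standin_model(k) for k in keys]

x = jnp.array([0.5, 1.0])

# Stack member predictions along a new leading axis: shape (M, n_outputs).
preds = jnp.array([m(x) for m in models])

mean = jnp.mean(preds, axis=0)
std = jnp.std(preds, axis=0, ddof=1)  # ddof=1: sample standard deviation

print(preds.shape)  # (5, 3)
print(mean, std)
```

With only five members, the spread is itself a noisy estimate, which is one reason the module asks you to check calibration on held-out simulations.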
WarningWhat You Must Figure Out

The public final-project docs and starter patterns give you the workflow, but you still need to make the key implementation decisions. You need to:

  1. Define your model class — How many layers? What dimensions? Where do activations go?

  2. Write the training step — How do you combine eqx.filter_value_and_grad, optimizer.update, and eqx.apply_updates?

  3. Handle batches — When do you use jax.vmap?

  4. Train the ensemble — How do you manage multiple models with different initializations?

The Equinox “Train a Neural Network” tutorial shows a complete example. Study it, understand it, then adapt it to your simulator-and-emulator workflow — don’t copy it blindly.

TipSaving Your Trained Models

You’ll need to save trained models for use in the inference phase. Equinox provides serialization via eqx.tree_serialise_leaves and eqx.tree_deserialise_leaves.

Important: Save your normalization statistics (mean, std for inputs and outputs) alongside your models — you’ll need them for inference. A simple approach is to save everything in a dictionary or use separate files.

See the Equinox documentation on serialization for details.


1.12: From Emulator to Inference

Priority: Essential

The Goal

You’ve trained a fast emulator. Now use it for Bayesian inference: given observed or held-out cluster summary statistics, infer the initial conditions that produced them.

The inference target is

\[ p(Q_0, a \,|\, \mathbf{y}_{\rm obs}), \]

where

\[ \mathbf{y}_{\rm obs} = \begin{bmatrix} f_{\rm bound,obs} \\ \sigma_{v,\rm obs} \\ r_{h,\rm obs} \end{bmatrix}. \]

This posterior asks: after seeing the cluster summaries, which initial virial ratios and scale radii remain plausible?

The emulator replaces the expensive N-body simulation inside the likelihood. NumPyro’s NUTS sampler will call that likelihood thousands of times, so speed matters. Scientific trust, however, comes from the likelihood and uncertainty model, not from speed alone.

From Emulator Prediction to Likelihood

A trained emulator gives a prediction:

\[ \hat{\mathbf{y}} = f_{\hat{\boldsymbol{\theta}}}(Q_0, a). \]

For the final project,

\[ \hat{\mathbf{y}} = \begin{bmatrix} \hat{f}_{\rm bound} \\ \hat{\sigma}_v \\ \hat{r}_h \end{bmatrix}. \]

But inference requires more than a prediction. It requires a probability:

\[ p(\mathbf{y}_{\rm obs} \,|\, Q_0, a). \]

This probability is the likelihood. It answers:

If the initial conditions were \((Q_0, a)\), how plausible would the observed summary statistics be?

To construct it, we need an error model. A simple first model is

\[ \mathbf{y}_{\rm obs} = f_{\hat{\boldsymbol{\theta}}}(Q_0, a) + \boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \Sigma_{\rm total}). \]

Then

\[ p(\mathbf{y}_{\rm obs} \,|\, Q_0, a, \hat{\boldsymbol{\theta}}, \Sigma_{\rm total}) = \mathcal{N} \left( \mathbf{y}_{\rm obs} \,\middle|\, f_{\hat{\boldsymbol{\theta}}}(Q_0, a), \Sigma_{\rm total} \right). \]

If we assume independent errors for the three summary statistics, then

\[ \Sigma_{\rm total} = \begin{pmatrix} \sigma^2_{f_{\rm bound}} & 0 & 0 \\ 0 & \sigma^2_{\sigma_v} & 0 \\ 0 & 0 & \sigma^2_{r_h} \end{pmatrix}. \]

This diagonal assumption is not automatic. It says errors in bound fraction, velocity dispersion, and half-mass radius are independent after conditioning on the model. If held-out residuals show correlated errors, this likelihood is too simple.
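One way to probe the diagonal assumption is to look at the correlation matrix of held-out residuals. The sketch below uses synthetic residuals in which the second and third outputs deliberately share an error term. The numbers are made up to show what a violated assumption looks like.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n_val = 500

# Synthetic residuals (pred - true), shape (n_val, 3).
# Columns 1 and 2 share a common error term, so they are correlated by design.
shared = jax.random.normal(key, (n_val, 1))
noise = jax.random.normal(jax.random.fold_in(key, 1), (n_val, 3))
residuals = noise + jnp.array([0.0, 1.0, 1.0]) * shared

# rowvar=False: each column is one output, each row one held-out simulation.
corr = jnp.corrcoef(residuals, rowvar=False)
print(jnp.round(corr, 2))
```

Off-diagonal entries near zero are consistent with a diagonal \(\Sigma_{\rm total}\). Entries well away from zero, as between the last two columns here, suggest a full covariance matrix is the more honest model.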

Bayes’ theorem then gives the object we actually want:

\[ p(Q_0, a \,|\, \mathbf{y}_{\rm obs}) \propto p(\mathbf{y}_{\rm obs} \,|\, Q_0, a) p(Q_0, a). \]

The prior \(p(Q_0, a)\) encodes physically plausible initial conditions and should also keep the sampler inside the domain where the emulator has scientific authority. The likelihood encodes agreement with the observed cluster summaries. The posterior is the set of initial conditions still plausible after seeing the data.

NoteLikelihood vs. Posterior

The likelihood and posterior condition in opposite directions:

  • Likelihood: \(p(\mathbf{y}_{\rm obs} \,|\, Q_0, a)\) asks how plausible the data are if the initial conditions are true.
  • Posterior: \(p(Q_0, a \,|\, \mathbf{y}_{\rm obs})\) asks how plausible the initial conditions are after seeing the data.

NumPyro samples the posterior. To do that, you must provide priors and a likelihood.

The Conceptual Framework

Probabilistic programming: A paradigm where you write a generative model (how data is produced from parameters), and the framework handles inference automatically.

In NumPyro, you write a generative model — code describing how observations arise from parameters. The structure directly mirrors Bayes’ theorem:

GENERATIVE MODEL STRUCTURE
==========================

1. PRIOR: What do we believe about parameters before seeing data?
   - Sample Q0 from some distribution (e.g., Uniform over training range)
   - Sample a from some distribution

2. FORWARD MODEL: Given parameters, what do we predict?
   - Normalize the sampled (Q0, a) using TRAINING statistics
   - Pass through your emulator to get predicted (f_bound, sigma_v, r_h)
   - Denormalize predictions back to physical units

3. LIKELIHOOD: How probable are actual observations given predictions?
   - Compare predictions to observations
   - Account for observational, simulation, and emulator uncertainty

The key insight: Your emulator serves as the forward model. Instead of running an expensive N-body simulation for each \((Q_0, a)\) the sampler proposes, you call your fast neural network.

What You Need to Figure Out

ImportantNumPyro Learning Path

The final project requires you to build a NumPyro model. Here’s what to learn:

NumPyro documentation (https://num.pyro.ai/):

  • “Getting Started”: Basic model structure with numpyro.sample
  • “Bayesian Regression”: Example closest to your task
  • MCMC and NUTS: How to run inference

Key NumPyro primitives you’ll use:

  • numpyro.sample("name", distribution) — sample a parameter from a prior
  • numpyro.sample("obs", distribution, obs=data) — define likelihood given observed data
  • dist.Uniform(low, high) — uniform prior
  • dist.Normal(loc, scale) — Gaussian likelihood
  • NUTS(model) — the sampler (you know this from Module 5!)
  • MCMC(kernel, num_warmup, num_samples) — run the chain

The Normalization Challenge

Your emulator was trained on normalized data. Your NumPyro model must handle this correctly:

NORMALIZATION FLOW
==================

Physical parameters (Q0, a) from prior
          |
    Normalize using TRAINING statistics
          |
Normalized inputs to emulator
          |
    Emulator predicts normalized outputs
          |
    Denormalize using TRAINING statistics
          |
Physical predictions (f_bound, sigma_v, r_h)
          |
    Compare to physical observations in likelihood
WarningCritical Detail

You must use the same normalization statistics (mean and std from training set) that you used during training. Store these when you train your emulator, and pass them to your inference model.

A common bug: accidentally using different statistics, which silently produces wrong predictions.

import jax.numpy as jnp

def predict_physical(model, norm_stats, x_physical):
    """Predict physical summary statistics from physical inputs.

    Purpose: apply the trained emulator inside an inference model while keeping
        normalization consistent with training.
    Inputs: model, saved normalization statistics, and x_physical with
        components [Q0, a].
    Output: y_physical = [f_bound, sigma_v, r_h] in physical units.
    Common pitfall: using validation/test statistics here instead of saved
        training statistics changes the forward model silently.
    """
    x_norm = (x_physical - norm_stats.x_mean) / norm_stats.x_std
    y_norm = model(x_norm)
    y_physical = y_norm * norm_stats.y_std + norm_stats.y_mean
    return y_physical

def diagonal_gaussian_loglike(y_obs, y_pred, sigma_vec):
    """Compute a diagonal Gaussian log-likelihood in physical units.

    Purpose: score agreement between observed and emulator-predicted summaries.
    Inputs: y_obs, y_pred, and sigma_vec in the same physical output units.
    Output: scalar log-likelihood.
    Common pitfall: sigma_vec has one entry per output, not one universal number
        for dimensionless, velocity, and length summaries.
    """
    residual = y_obs - y_pred
    return -0.5 * jnp.sum(
        (residual / sigma_vec) ** 2 + jnp.log(2 * jnp.pi * sigma_vec ** 2)
    )
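The two helpers above leave the form of `norm_stats` open. One possible choice, shown here as an assumption, is a small `NamedTuple` filled from the training arrays. The sketch repeats both helpers so it runs on its own, uses made-up training values, and replaces the emulator with a stand-in that always returns zeros in normalized space.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.scipy.stats import norm


class NormStats(NamedTuple):
    """Hypothetical container for training-set normalization statistics."""
    x_mean: jnp.ndarray
    x_std: jnp.ndarray
    y_mean: jnp.ndarray
    y_std: jnp.ndarray


def fit_norm_stats(x_train, y_train):
    # Statistics come from the TRAINING arrays only.
    return NormStats(x_train.mean(axis=0), x_train.std(axis=0),
                     y_train.mean(axis=0), y_train.std(axis=0))


def predict_physical(model, norm_stats, x_physical):
    x_norm = (x_physical - norm_stats.x_mean) / norm_stats.x_std
    y_norm = model(x_norm)
    return y_norm * norm_stats.y_std + norm_stats.y_mean


def diagonal_gaussian_loglike(y_obs, y_pred, sigma_vec):
    residual = y_obs - y_pred
    return -0.5 * jnp.sum(
        (residual / sigma_vec) ** 2 + jnp.log(2 * jnp.pi * sigma_vec ** 2)
    )


# Made-up training arrays: 3 samples, 2 inputs, 3 outputs.
x_train = jnp.array([[0.5, 50.0], [1.0, 100.0], [1.5, 150.0]])
y_train = jnp.array([[0.9, 1.0, 2.0], [0.6, 2.0, 4.0], [0.3, 3.0, 6.0]])
stats = fit_norm_stats(x_train, y_train)

# Stand-in "emulator": zeros in normalized space, i.e. the training mean.
standin_model = lambda x_norm: jnp.zeros(3)
y_pred = predict_physical(standin_model, stats, jnp.array([1.0, 100.0]))

# Cross-check the hand-written log-likelihood against jax.scipy.
sigma_vec = jnp.array([0.05, 0.2, 0.3])
y_obs = jnp.array([0.62, 1.9, 4.1])
ll = diagonal_gaussian_loglike(y_obs, y_pred, sigma_vec)
ll_ref = jnp.sum(norm.logpdf(y_obs, loc=y_pred, scale=sigma_vec))
print(y_pred, float(ll), float(ll_ref))
```

A zero prediction in normalized space denormalizes to the training mean, which makes a convenient test that the statistics are wired in the right direction.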

Choosing the Likelihood Width

The likelihood requires an uncertainty model. A useful starting point is a diagonal vector,

\[ \boldsymbol{\sigma}_{\rm total} = \begin{bmatrix} \sigma_{f_{\rm bound}} \\ \sigma_{\sigma_v} \\ \sigma_{r_h} \end{bmatrix}, \]

with one component for each output. These entries must be in the same physical units as \(\mathbf{y}_{\rm obs}\).

For synthetic observations from your simulations, this uncertainty may include:

  1. Emulator error: the network does not perfectly predict simulations.
  2. Simulation stochasticity: different random seeds or initial realizations produce scatter even for fixed \((Q_0, a)\).
  3. Numerical uncertainty: timestep, integration, or finite-precision effects perturb the summary statistic.
  4. Observational uncertainty: only relevant if you compare to real observed cluster summaries.

Practical baseline: Use validation/calibration-set RMSE for each output as a first estimate of \(\boldsymbol{\sigma}_{\rm total}\). Keep the test set untouched for final reporting and parameter-recovery checks after the likelihood width is fixed.

COMPUTING sigma_total
=====================

1. Get emulator predictions on validation/calibration set (normalized)
2. Denormalize to physical units
3. Compare to true validation/calibration values
4. RMSE = sqrt(mean((pred - true)**2)) for each output
5. Use these three RMSE values as an initial sigma_total
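These five steps can be sketched in a few lines. The arrays below are synthetic stand-ins for validation targets and emulator predictions, and the output statistics are invented. Only the denormalize-then-RMSE pattern is the point.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n_val = 400

# Made-up training statistics for the three outputs.
y_mean = jnp.array([0.5, 2.0, 1.0])
y_std = jnp.array([0.2, 0.5, 0.3])

# Synthetic normalized truth and predictions (prediction = truth + small error).
y_true_norm = jax.random.normal(key, (n_val, 3))
err_norm = 0.1 * jax.random.normal(jax.random.fold_in(key, 1), (n_val, 3))
y_pred_norm = y_true_norm + err_norm

# Denormalize both to physical units before comparing.
y_true = y_true_norm * y_std + y_mean
y_pred = y_pred_norm * y_std + y_mean

# One RMSE per output column: axis=0 averages over validation samples.
sigma_total = jnp.sqrt(jnp.mean((y_pred - y_true) ** 2, axis=0))
print(sigma_total)  # one entry per output, in that output's physical units
```

Averaging over `axis=0` is what keeps the three outputs separate. Averaging over everything would mix a dimensionless fraction with a velocity and a length.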
NoteUncertainty Calibration

If \(\boldsymbol{\sigma}_{\rm total}\) is too small, posteriors will be overconfident and may miss true values. If too large, posteriors will be vague and uninformative.

Validation/calibration RMSE is a principled starting point. You can check calibration: do about 95% of your validation posteriors contain the true parameter values within their 95% credible intervals?

Using an Ensemble in the Likelihood

If you train an ensemble, you can estimate an input-dependent emulator covariance. For \(M\) ensemble members,

\[ \boldsymbol{\mu}_{\rm emu}(\mathbf{x}) = \frac{1}{M} \sum_{m=1}^{M} f_{\boldsymbol{\theta}_m}(\mathbf{x}), \]

and

\[ \Sigma_{\rm emu}(\mathbf{x}) = \frac{1}{M - 1} \sum_{m=1}^{M} \left( f_{\boldsymbol{\theta}_m}(\mathbf{x}) - \boldsymbol{\mu}_{\rm emu}(\mathbf{x}) \right) \left( f_{\boldsymbol{\theta}_m}(\mathbf{x}) - \boldsymbol{\mu}_{\rm emu}(\mathbf{x}) \right)^T. \]

A simple ensemble-aware likelihood is then

\[ p(\mathbf{y}_{\rm obs} \,|\, \mathbf{x}) = \mathcal{N} \left( \mathbf{y}_{\rm obs} \,\middle|\, \boldsymbol{\mu}_{\rm emu}(\mathbf{x}), \Sigma_{\rm obs} + \Sigma_{\rm emu}(\mathbf{x}) \right). \]

This is still an approximation. It treats ensemble spread as a proxy for emulator uncertainty, so it must be checked with held-out simulations.
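The ensemble mean, covariance, and likelihood above can be evaluated at a single input as in the sketch below. The five prediction rows, the observation, and the diagonal \(\Sigma_{\rm obs}\) are made-up numbers chosen only to exercise the formulas.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Made-up predictions from M = 5 ensemble members at one input, shape (M, 3).
preds = jnp.array([
    [0.61, 2.02, 1.10],
    [0.59, 1.97, 1.05],
    [0.62, 2.05, 1.12],
    [0.60, 1.99, 1.08],
    [0.58, 2.01, 1.06],
])

# Ensemble mean and sample covariance with the 1/(M - 1) normalization.
mu_emu = jnp.mean(preds, axis=0)
dev = preds - mu_emu
sigma_emu = dev.T @ dev / (preds.shape[0] - 1)

# Assumed diagonal observational covariance (illustrative values).
sigma_obs = jnp.diag(jnp.array([0.02, 0.05, 0.03]) ** 2)

y_obs = jnp.array([0.60, 2.00, 1.07])
loglike = multivariate_normal.logpdf(y_obs, mu_emu, sigma_obs + sigma_emu)
print(float(loglike))
```

With only a handful of members, \(\Sigma_{\rm emu}\) alone can be poorly conditioned. Adding \(\Sigma_{\rm obs}\) keeps the total covariance positive definite in this example.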

TipPredict – Try – Explain: Degeneracy in \((Q_0, a)\)

Predict: Could a larger initial scale radius with lower virial ratio produce a final half-mass radius similar to a more compact but dynamically hotter cluster?

Try: Sketch likelihood contours in \((Q_0, a)\) space for one summary statistic, then sketch how adding \(f_{\rm bound}\) and \(\sigma_v\) might narrow or rotate the contour.

Explain: Multiple initial conditions can produce similar summaries. The posterior may be an elongated ridge, not a round blob. This is why multiple summary statistics and honest uncertainty matter.

Validation: Parameter Recovery

The critical test of your inference pipeline:

PARAMETER RECOVERY TEST
=======================

1. Pick a simulation from your TEST set
   - You know the true (Q0, a) that generated it
   - You have its summary statistics (f_bound, sigma_v, r_h)

2. Treat summary statistics as "observations"

3. Run inference to get posterior over (Q0, a)

4. Check: Do 95% credible intervals contain true values?
   - Compute 2.5th and 97.5th percentiles of posterior samples
   - True value should fall within this range

5. Repeat for several test cases
   - If ~95% contain truth, your pipeline is well-calibrated
   - If fewer, sigma_total may be too small (overconfident)
   - If more, sigma_total may be too large (uninformative)
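The coverage bookkeeping in steps 4 and 5 can be sketched with synthetic posteriors before you run any real inference. Here the "posterior samples" are Gaussian draws constructed to be well calibrated. The case count, draw count, and posterior width are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k_truth, k_center, k_samples = jax.random.split(key, 3)
n_cases, n_draws, post_sd = 200, 2000, 0.1

# True parameter for each recovery test.
truth = jax.random.uniform(k_truth, (n_cases,), minval=0.5, maxval=1.5)

# Synthetic posteriors: centers scatter around the truth with the same width
# the posterior reports, which is what "well calibrated" means here.
centers = truth + post_sd * jax.random.normal(k_center, (n_cases,))
samples = centers[:, None] + post_sd * jax.random.normal(
    k_samples, (n_cases, n_draws)
)

# 95% credible interval per test case from posterior percentiles.
lo, hi = jnp.percentile(samples, jnp.array([2.5, 97.5]), axis=1)
covered = (truth >= lo) & (truth <= hi)
coverage = float(jnp.mean(covered))
print(f"Empirical 95% coverage: {coverage:.2f}")
```

With a finite number of test cases the empirical coverage fluctuates around 0.95, so a handful of recovery tests can only rule out gross miscalibration.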

Visualization: Create corner plots showing the joint posterior over \((Q_0, a)\) with true values marked. The true values should fall within the 95% contours.

NoteParameter-Recovery Plot

For each held-out recovery test, make a corner plot for \((Q_0, a)\) with the true values marked. The 68% and 95% contours should be interpreted only after the likelihood width has been calibrated on validation data.

Connecting the Pieces

Here’s how all the components fit together in your final project:

COMPLETE INFERENCE PIPELINE
===========================

┌─────────────────────────────────────────────────────────┐
│  TRAINING PHASE (done before inference)                 │
├─────────────────────────────────────────────────────────┤
│  1. Generate N-body training data                       │
│  2. Compute and store normalization statistics          │
│  3. Train ensemble of emulators                         │
│  4. Estimate sigma_total from validation residuals      │
│     and reserve test set for final reporting            │
└─────────────────────────────────────────────────────────┘
                          ↓
┌─────────────────────────────────────────────────────────┐
│  INFERENCE PHASE                                        │
├─────────────────────────────────────────────────────────┤
│  1. Define NumPyro model:                               │
│     - Priors on (Q0, a)                                 │
│     - Emulator as forward model (with normalization!)   │
│     - Gaussian likelihood with sigma_total              │
│                                                         │
│  2. Run NUTS:                                           │
│     - Warmup: adapts step size and mass matrix          │
│     - Sampling: collects posterior samples              │
│                                                         │
│  3. Analyze posterior:                                  │
│     - Summary statistics (mean, std, credible intervals)│
│     - Corner plots                                      │
│     - Parameter recovery validation                     │
└─────────────────────────────────────────────────────────┘
TipWhy This Is Fast

Remember the calculation from Section 1.1? Direct inference with N-body simulations can become infeasible because NUTS may need many forward-model calls.

With your emulator:

  • NUTS can call the forward model many thousands of times during warmup and sampling
  • Each JIT-compiled emulator call is typically much faster than rerunning a simulation
  • The actual speedup depends on model size, hardware, vectorization, and sampler settings

The bottleneck usually becomes sampler overhead rather than the forward model. Report measured wall-clock time from your own machine instead of quoting a universal speedup factor.

This is why JIT compilation matters: the @eqx.filter_jit decorator lets JAX compile repeated model calls. It often gives large speedups, but the honest claim is the one you benchmark for your pipeline.
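A simple way to produce that benchmark is sketched below. It times a JIT-compiled stand-in network written with plain arrays, so `jax.jit` suffices here. For an Equinox model you would wrap the call with `eqx.filter_jit` instead. The network sizes and call count are arbitrary.

```python
import time

import jax
import jax.numpy as jnp


def standin_forward(params, x):
    """Hypothetical stand-in for an emulator: 2 -> 64 -> 3 with tanh."""
    W1, b1, W2, b2 = params
    return W2 @ jnp.tanh(W1 @ x + b1) + b2


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = (
    jax.random.normal(k1, (64, 2)), jnp.zeros(64),
    jax.random.normal(k2, (3, 64)), jnp.zeros(3),
)
x = jnp.array([0.5, 1.0])

fast = jax.jit(standin_forward)
fast(params, x).block_until_ready()  # first call compiles; exclude it from timing

n_calls = 1000
start = time.perf_counter()
for _ in range(n_calls):
    out = fast(params, x)
out.block_until_ready()  # wait for asynchronous dispatch to finish
elapsed = time.perf_counter() - start

print(f"{1e6 * elapsed / n_calls:.1f} microseconds per call on this machine")
```

The two details that matter are excluding the compilation call and calling `block_until_ready()` before stopping the clock. Without them the number mostly measures compilation or dispatch, not the forward model.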

ImportantMinimum Viable Final-Project Workflow

Before you add optional sophistication, make this baseline work:

If these pieces are solid, the project is scientifically defensible. Optional uncertainty methods should strengthen this workflow, not distract from it.


1.13: Synthesis

Priority: Essential

The Complete Workflow

You’ve learned a powerful workflow for modern computational science:

flowchart TD
    subgraph Data["1. Generate Training Data"]
        sims["Run JAX-native Leapfrog simulations\n(Latin Hypercube in Q0, a)"]
        stats["Extract summary statistics\n(f_bound, sigma_v, r_h)"]
        sims --> stats
    end

    subgraph Prep["2. Data Preparation"]
        split["Train/validation/test split"]
        norm["Normalize to zero mean,\nunit variance"]
        split --> norm
    end

    subgraph Arch["3. Design Architecture"]
        mlp["MLP: 2 -> 64 -> 64 -> 3"]
        act["ReLU hidden, linear output"]
    end

    subgraph Train["4. Train Network"]
        init["Xavier initialization"]
        opt["Adam optimizer"]
        loss["MSE loss"]
        init --> opt --> loss
    end

    subgraph Ens["5. Build Ensemble"]
        multi["Train 5 models\n(different seeds)"]
        uncert["Estimate emulator spread\nthen check calibration"]
        multi --> uncert
    end

    subgraph Eval["6. Evaluate"]
        test["Held-out accuracy"]
        plots["Predicted vs. true\nand residual diagnostics"]
        edges["Domain and edge behavior"]
    end

    subgraph Infer["7. Inference"]
        numpyro["NumPyro model:\nprior -> emulator -> physical likelihood"]
        nuts["NUTS sampling"]
        post["Posterior over (Q0, a)"]
        numpyro --> nuts --> post
    end

    Data --> Prep --> Arch --> Train --> Ens --> Eval --> Infer

Connections Across the Course

| Concept | Where You Learned It | How It Appears Here |
| --- | --- | --- |
| Models as compression | Module 1 | Emulator compresses input \(\to\) output relationship into weights |
| Moments & summary statistics | Module 1 | \(f_{\rm bound}\), \(\sigma_v\), \(r_h\) compress simulation output |
| Likelihood & MSE | Module 5 | MSE can be derived from a Gaussian residual model under specific assumptions |
| Gradient-based inference | Module 5 (HMC) | Training uses gradients; NUTS uses gradients |
| Automatic differentiation | The JAX module | Backprop = reverse-mode autodiff (JAX handles it) |
| JIT compilation | The JAX module | Essential for fast emulator evaluation in inference |
| Functional programming | The JAX module | Equinox models are pure functions with explicit state |
| N-body dynamics | Module 3, Project 2, and the final-project JAX rebuild | The simulator map your emulator approximates |

What Makes a Good Emulator?

Accuracy: Predictions match held-out simulations within the training range

  • Test RMSE should be small relative to output variation
  • Predicted vs. true plots cluster along \(y = x\)

Calibration: Uncertainty estimates are meaningful

  • About 95% of true values fall within calibrated 95% prediction intervals
  • Uncertainty increases where training data is sparse
  • Posterior recovery tests have credible intervals with sensible empirical coverage

Speed: Fast enough for inference

  • Much faster per prediction than a simulation (often microseconds once JIT-compiled, vs. seconds or more per simulation); benchmark on your own hardware
  • JIT compilation is essential

Robustness: Sensible behavior at boundaries

  • No claims outside the training range unless explicitly tested
  • Ensemble uncertainty increases appropriately
  • Residual diagnostics do not reveal systematic bias across \(Q_0\), \(a\), or predicted values

Looking Ahead: The Final Project

You’re now ready for the final project. You’ll:

  1. Generate training data: Latin Hypercube sampling in \((Q_0, a)\) space; compute summary statistics from N-body simulations using your JAX-native final-project simulator

  2. Build and train an emulator: MLP in Equinox; train ensemble with Optax

  3. Evaluate accuracy: Test metrics, predicted vs. true plots, uncertainty behavior

  4. Perform inference: NumPyro model with emulator as forward model; NUTS sampling; corner plots showing parameter recovery

This workflow — expensive simulations \(\to\) fast surrogates \(\to\) Bayesian inference — is exactly how frontier research operates. You’re learning it from the inside out.


Final Conceptual Checkpoint

Before starting the final project, ensure you can answer:

ImportantComprehensive Self-Assessment

Architecture & Theory:

  1. What is the role of activation functions? What happens without them?
  2. What does the Universal Approximation Theorem guarantee? What doesn’t it guarantee?
  3. How many parameters does a network with layers 5 \(\to\) 100 \(\to\) 100 \(\to\) 3 have?

Training:

  1. Explain how backpropagation relates to jax.grad and the autodiff you learned in the JAX module.
  2. Why must normalization use only training set statistics?
  3. Your training loss decreases but validation loss increases after epoch 200. What’s happening?
  4. What hyperparameters would you adjust if loss oscillates wildly?

Uncertainty:

  1. Why do different random seeds produce different trained networks?
  2. What type of uncertainty do ensembles capture? What don’t they capture?
  3. Where would you expect ensemble uncertainty to be highest?

Integration:

  1. In your NumPyro model, what role does the emulator play?
  2. How do you choose \(\boldsymbol{\sigma}_{\rm total}\) for the likelihood?
  3. How would you check if your inference is working correctly?

Practical:

  1. What does eqx.filter_jit do differently from jax.jit?
  2. What is the purpose of jax.vmap(model) in the training step?
  3. Why is the output layer linear (no activation) for your emulator?

Further Reading and Resources

Essential (Complete Before Final Project)

For Deeper Understanding

Emulation in Astrophysics

  • Cranmer et al. (2020): “The frontier of simulation-based inference”
  • Alsing et al. (2019): “Fast likelihood-free cosmology with neural density estimators and active learning”
    • Example of emulator-based inference in cosmology

Now go build something that learns.