— Richard Feynman (found on his blackboard at the time of his death, 1988)
Learning Objectives
By the end of this module, you will be able to:
Important: Required Path and Optional Depth
Required for the final project: Understand the emulator hierarchy, train/validation/test discipline, the 2-input MLP emulator, normalization with training-set statistics, baseline comparisons, held-out residual diagnostics, ensemble uncertainty, calibration basics, and how the emulator becomes the forward model inside a likelihood.
Optional depth: The theorem details, extended ecosystem examples, MC dropout, predicted variance heads, conformal prediction, Bayesian neural networks, and advanced inference discussion help you see the bigger picture. Use them after the baseline emulator is working and validated.
Note: Your Roadmap Through This Module
Core Question: Can a validated neural-network emulator replace repeated N-body simulations inside Bayesian inference for cluster initial conditions without destroying scientific trust?
The Big Picture: You’ve built N-body simulations in Project 2, and you’re now rebuilding that simulator in JAX for the final project. You’ve also built Bayesian inference engines in Project 4. Now you’ll learn to build fast approximations to expensive computations. This is the key to the final project: training a neural network to predict N-body simulation summaries, assigning uncertainty to those predictions, and using the emulator inside a likelihood for inference.
Section 1.1: The Emulation Problem. Why we need fast surrogates for expensive simulations, and what exactly the emulator is approximating.
Section 1.2: From Linear Regression to Neural Networks. The intellectual progression: linear models \(\to\) feature engineering \(\to\) learned features.
Section 1.3: The Computational Neuron. The mathematical building blocks: neurons, activations, and why nonlinearity matters.
Section 1.4: Layering Neurons into Networks. Architecture design: hidden layers, parameter counting, and the expressiveness tradeoff.
Section 1.5: The Universal Approximation Theorem. Why neural networks can (in principle) approximate any continuous function. The profound idea and its limits.
Section 1.6: Forward Propagation. How input becomes output, and the computational graph perspective you know from the JAX module.
Section 1.7: Training as Optimization. Loss functions, gradient descent, and the connection to likelihood maximization.
Section 1.8: Weight Initialization — Breaking Symmetry. Why random initialization matters and how it enables ensembles.
Section 1.9: Practical Training. Normalization, learning rates, convergence, multi-output handling, and debugging.
Section 1.10: Uncertainty via Ensembles. Why single networks are dangerous, and how ensembles estimate one part of epistemic uncertainty.
Section 1.11: The JAX Ecosystem — Equinox and Optax. Professional tools for building and training neural networks.
Section 1.12: From Emulator to Inference. Constructing the physical-unit likelihood and connecting your trained network to NumPyro for Bayesian parameter recovery.
Section 1.13: Synthesis. What you’ve learned and how it all connects.
Tip: Vocabulary: AI, ML, Deep Learning, and Emulation
These words often get blurred together, so let’s separate them clearly:
Artificial Intelligence (AI): The broad umbrella for systems that perform tasks we associate with intelligence.
Machine Learning (ML): A subset of AI in which a model learns patterns from data instead of being programmed rule by rule.
Deep Learning: A subset of ML that uses multi-layer neural networks.
Scientific emulation: A specific ML use case in which we train a fast surrogate model to approximate an expensive scientific computation.
This reading is not trying to make you a general AI engineer. It is teaching one focused and useful workflow: how to build a trustworthy neural-network emulator for a scientific simulator, evaluate it honestly, and use it for inference.
1.1: The Emulation Problem
Priority: Essential
Why Emulators Matter
Emulator (or surrogate model): A fast approximation to an expensive computation. Trained on input-output pairs from the expensive model, then used in place of it.
Consider your N-body simulator. Each run takes seconds to minutes. But scientific inference requires thousands of forward-model evaluations. In direct inference, each likelihood evaluation may call the physical simulator. In emulator-based inference, each likelihood evaluation calls a trained surrogate. The question is whether the surrogate is accurate enough, calibrated enough, and limited to a trustworthy domain.
The computational bottleneck: Your final project asks you to infer what initial conditions \((Q_0, a)\) produced a given cluster outcome. If one simulation takes about 30 seconds, then 2000 likelihood evaluations would cost
\[
2000 \times 30\ \text{s} = 6 \times 10^{4}\ \text{s} \approx 17\ \text{hours}.
\]
The exact numbers depend on your simulator, hardware, timestep, and inference setup. The lesson is the scaling: direct simulation inside inference can become infeasible quickly.
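The scaling argument can be made concrete with a few lines of arithmetic. This is a back-of-the-envelope sketch: the 30 s per run and 2000 evaluations come from the example above, while the 1 ms per emulator call and the 100-run training set are illustrative assumptions, not measurements from your pipeline.

```python
def inference_cost_hours(n_evals: int, seconds_per_eval: float) -> float:
    """Wall-clock hours for n_evals sequential forward-model evaluations."""
    return n_evals * seconds_per_eval / 3600.0


n_evals = 2000
direct_hours = inference_cost_hours(n_evals, 30.0)    # N-body simulator inside the likelihood
emulator_hours = inference_cost_hours(n_evals, 1e-3)  # ASSUMED ~1 ms per emulator call
training_hours = inference_cost_hours(100, 30.0)      # ASSUMED one-time cost of ~100 training runs

print(f"direct inference:   {direct_hours:.1f} h")
print(f"emulator inference: {emulator_hours * 3600:.1f} s")
print(f"training set:       {training_hours * 60:.0f} min (paid once)")
```

The point of the comparison is that the simulator cost moves from inside the inference loop to a one-time training-set budget.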
The emulator solution: Train a neural network on a modest, validated simulation ensemble. The network learns a fast approximation to the simulator’s summary-statistic map. Now inference can become an interactive analysis step, provided the held-out validation evidence supports trusting the emulator.
Tip: Predict – Try – Explain: Why Emulate?
Predict: Which part of the final project becomes impossible if each likelihood evaluation reruns the N-body simulator?
Try: Write the two workflows side by side: direct inference calls the simulator; emulator-based inference calls the trained surrogate.
Explain: The emulator is not replacing physics. It is replacing repeated expensive evaluations only after the simulator and held-out emulator tests are credible.
What Exactly Is Being Emulated?
The final project is not “train a neural network because neural networks are powerful.” The scientific problem is an inverse problem. We want to infer which initial cluster conditions,
\[
\mathbf{x} = (Q_0, a),
\]
could have produced a measured or simulated cluster outcome,
\[
\mathbf{y} = (f_{\rm bound}, \sigma_v, r_h).
\]
Here \(f_{\rm bound}\) is dimensionless, \(\sigma_v\) has velocity units, and \(r_h\) has length units. That unit mismatch is not cosmetic. It is why we normalize outputs for training and why we must be careful when building a physical likelihood later.
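Let \(\Phi_T\) denote the numerical N-body evolution from \(t = 0\) to final time \(T\), and let \(s\) denote the summary-statistic extractor applied to the final particle state. The physical forward model is the composition
\[
\mathbf{y} = F(\mathbf{x}) = s\big(\Phi_T(\mathbf{x})\big),
\]
where \(\Phi_T\) starts from the initial particle state generated from \(\mathbf{x}\). The emulator is a neural network \(f_{\boldsymbol{\theta}}\) with weights \(\boldsymbol{\theta}\), trained so that \(f_{\boldsymbol{\theta}}(\mathbf{x}) \approx F(\mathbf{x})\).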
What this equation is really saying: the network is not learning Newton’s laws directly. It is learning a compact approximation to the simulator’s input-output map over the parameter domain you sampled. Whether that approximation captures scientifically meaningful structure must be demonstrated by validation, residual analysis, and failure-mode checks.
```mermaid
flowchart TD
x["Physical inputs\n(Q0, a)"] --> phi["N-body evolution Phi_T"]
phi --> zT["Final particle state"]
zT --> s["Summary extractor s"]
s --> y["Summary vector\n(f_bound, sigma_v, r_h)"]
y --> D["Training data"]
D --> emu["MLP emulator f_theta(Q0, a)"]
emu --> res["Residuals\ny - f_theta(x)"]
res --> like["Likelihood\np(y_obs | Q0, a)"]
like --> post["Posterior\np(Q0, a | y_obs)"]
```
Note: Connection to Module 1: Models as Compression
Remember Module 1’s insight that models compress information? Summary statistics compress high-dimensional simulation outputs (particle positions and velocities) into meaningful numbers (\(f_{\rm bound}\), \(\sigma_v\), \(r_h\)). An emulator goes further — it compresses the simulator’s summary-statistic map into a set of neural-network weights.
This is learned compression: the network discovers a useful representation of the mapping over the sampled domain. Your approximately 4500 network parameters encode a fitted approximation to how initial virial ratio and scale radius map to the chosen cluster summaries. They do not encode the fundamental equations of motion in the way your simulator does.
Emulators in Modern Astrophysics
This isn’t a toy problem. Emulation is transforming computational science:
| Field | Expensive forward model | Emulator predicts | Inference target |
|---|---|---|---|
| Cosmology | N-body or hydrodynamic simulations | matter power spectrum or summary statistics | cosmological parameters |
| Stellar evolution | MESA stellar tracks | luminosity, radius, lifetime, or \(T_{\rm eff}\) | age, mass, metallicity |
| Galaxy formation | large hydrodynamic simulations | galaxy properties from model parameters | feedback or cosmological parameters |
| Final project | cluster N-body simulation | \((f_{\rm bound}, \sigma_v, r_h)\) | \((Q_0, a)\) |
The common pattern is not “replace physics with ML.” The common pattern is: trust a validated simulator, train a fast approximation to its summary outputs, and use that approximation only where validation supports it.
Important: Minimum Emulator Workflow
If you feel overwhelmed by the vocabulary, keep this compact recipe in mind:
Build and validate your JAX-native simulator first.
Generate a modest training set over a clearly defined parameter range.
Compute summary statistics that will be your emulator targets.
Normalize inputs and outputs using training-set statistics only.
Train one small MLP as a baseline emulator.
Compare against mean, linear, and simple nearest-neighbor or interpolation-style baselines where feasible.
Add an ensemble for one component of epistemic uncertainty and check calibration.
Build a physical-unit likelihood and plug the emulator into NumPyro.
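Step 4 of this recipe (normalize with training-set statistics only) depends on first holding out validation and test simulations. Here is a minimal JAX sketch of that split followed by normalization. The helper names (`split_indices`, `fit_normalizer`, `normalize`, `denormalize`), the 70/15/15 split, and the random placeholder arrays standing in for simulator inputs and outputs are all hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp


def split_indices(key, n, frac_train=0.7, frac_val=0.15):
    """Shuffle 0..n-1 once and cut it into disjoint train/validation/test index sets."""
    perm = jax.random.permutation(key, n)
    n_train = int(frac_train * n)
    n_val = int(frac_val * n)
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


def fit_normalizer(train_array):
    """Per-column mean and standard deviation computed from TRAINING rows only."""
    return train_array.mean(axis=0), train_array.std(axis=0)


def normalize(a, mean, std):
    return (a - mean) / std


def denormalize(a_norm, mean, std):
    return a_norm * std + mean


# Placeholder data standing in for simulator runs: X = (Q0, a), Y = 3 summaries.
key_split, key_x, key_y = jax.random.split(jax.random.PRNGKey(42), 3)
X = jax.random.uniform(key_x, (100, 2))
Y = jax.random.normal(key_y, (100, 3))

train_idx, val_idx, test_idx = split_indices(key_split, X.shape[0])
x_mean, x_std = fit_normalizer(X[train_idx])
y_mean, y_std = fit_normalizer(Y[train_idx])

X_train_n = normalize(X[train_idx], x_mean, x_std)
X_val_n = normalize(X[val_idx], x_mean, x_std)  # reuse the TRAINING statistics
Y_train_n = normalize(Y[train_idx], y_mean, y_std)
```

Validation and test rows are transformed with the training mean and standard deviation, never with their own statistics, so no information from held-out data leaks into preprocessing. This sketch does not guard against a constant column (`std = 0`).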
That is the core final-project lane. Everything else in this module is here to help you understand and defend those steps.
1.2: From Linear Regression to Neural Networks
Priority: Important
Before diving into neural network details, let’s see how they emerge naturally from ideas you already know.
The Progression of Function Approximation
Feature engineering: Manually constructing input transformations (like \(x^2\), \(\log x\), \(x_1 x_2\)) to capture nonlinear relationships. Traditional approach before neural networks automated this process.
Level 1: Linear Regression
The simplest model: \(\hat{y} = \mathbf{w}^T \mathbf{x} + b\)
You find weights \(\mathbf{w}\) and bias \(b\) that minimize squared error. This works beautifully when the true relationship is linear — but nature rarely cooperates.
Level 2: Polynomial/Basis Regression
To capture nonlinearity, you engineer features: \(\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 + \ldots\)
This is still linear in the parameters (the \(w_i\)), but nonlinear in the input. The problem: you must choose which features to include. For your emulator, should you use \(Q_0^2\)? \(\log a\)? \(Q_0 \times a\)? You’d have to guess.
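To see that basis regression is still linear in the parameters, here is a minimal JAX sketch that builds engineered polynomial features by hand and solves for the weights with ordinary least squares. The toy quadratic target and the helper name `polynomial_design` are illustrative assumptions.

```python
import jax.numpy as jnp


def polynomial_design(x, degree):
    """Design matrix with columns 1, x, x^2, ..., x^degree (hand-chosen features)."""
    return jnp.stack([x ** k for k in range(degree + 1)], axis=1)


x = jnp.linspace(-1.0, 1.0, 50)
y = 1.0 + 2.0 * x - 0.5 * x ** 2  # toy target that happens to be exactly quadratic

X = polynomial_design(x, degree=3)
w, *_ = jnp.linalg.lstsq(X, y)  # linear least squares in the weights
print(w)
```

Because the target is exactly quadratic, least squares recovers weights close to \((1, 2, -0.5, 0)\). The catch the text describes remains: someone had to decide that powers of \(x\) were the right features.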
Level 3: Neural Networks
Here’s the key insight: what if we learned the features?
A neural network with one hidden layer computes: \[\hat{y} = \mathbf{w}_{\rm out}^T \underbrace{\sigma(\mathbf{W}_{\rm hidden} \mathbf{x} + \mathbf{b}_{\rm hidden})}_{\text{learned features}} + b_{\rm out}\]
The hidden layer creates learned nonlinear features of the input. The output layer then does linear regression on these features. Training optimizes both the feature extraction and the final regression simultaneously.
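The "learned features, then linear regression" reading of this formula can be checked directly. In this JAX sketch, the \(\sin 2x\) target, the 32 hidden units, and the random hidden weights are all assumptions for illustration. The hidden layer is frozen at random values, so fitting the output layer reduces to ordinary least squares on the ReLU features. Real training differs: it adjusts the hidden weights too, which is what turns random features into learned ones.

```python
import jax
import jax.numpy as jnp


def hidden_features(W_hidden, b_hidden, x):
    """ReLU features; x: (N, d_in), W_hidden: (n_hidden, d_in), b_hidden: (n_hidden,)."""
    return jax.nn.relu(x @ W_hidden.T + b_hidden)


def predict(W_hidden, b_hidden, w_out, b_out, x):
    """Output layer: a linear combination of the hidden features plus a bias."""
    return hidden_features(W_hidden, b_hidden, x) @ w_out + b_out


key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
n_hidden = 32
W_hidden = jax.random.normal(key_w, (n_hidden, 1))
b_hidden = jax.random.normal(key_b, (n_hidden,))

x = jnp.linspace(-2.0, 2.0, 200)[:, None]
y = jnp.sin(2.0 * x[:, 0])  # toy 1D target

# With the hidden layer frozen, the output layer is ordinary least squares on H.
H = hidden_features(W_hidden, b_hidden, x)
A = jnp.concatenate([H, jnp.ones((H.shape[0], 1))], axis=1)  # extra column for b_out
coef, *_ = jnp.linalg.lstsq(A, y)
w_out, b_out = coef[:-1], coef[-1]

rmse = jnp.sqrt(jnp.mean((predict(W_hidden, b_hidden, w_out, b_out, x) - y) ** 2))
print("RMSE with frozen random features:", float(rmse))
```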
Note: Three Levels of Function Approximation

| Approach | What the human chooses | What the algorithm learns |
|---|---|---|
| Linear regression | The raw inputs | Weights on those inputs |
| Polynomial regression | Features such as \(x^2\) and \(x^3\) | Weights on chosen features |
| Neural network | Architecture and training data | Nonlinear features and weights together |
The neural network does not remove human judgment. You still choose the data, architecture, loss, and validation evidence.
Important: The Profound Shift
Traditional statistics: Human chooses features \(\to\) algorithm finds weights
Neural networks: Algorithm finds features and weights
This is why neural networks work across such diverse problems: they learn task-appropriate representations from data instead of relying on hand-built features. The same building blocks that predict cluster dynamics can, with different architectures, data, and trained weights, recognize galaxies in images or denoise spectra.
Note: Visual Companion: What Is a Neural Network?
3Blue1Brown’s “But what is a neural network?” introduces the same basic architecture through handwritten-digit recognition: input activations, hidden layers, weights, biases, and output activations.
Lecture framing: 3Blue1Brown uses handwritten digits. Our problem is not classification; it is scientific emulation.
As you watch, translate the example into the final-project emulator:
Digit pixels become physical inputs \((Q_0, a)\).
Digit probabilities become summary-statistic predictions \((f_{\rm bound}, \sigma_v, r_h)\).
Image features become learned nonlinear features of the simulation design space.
The example is classification; your emulator is regression. The computational structure is the same kind of layered function.
Credit: 3Blue1Brown / Grant Sanderson
Where Does Your Emulator Fit?
Your task: learn the mapping \((Q_0, a) \to (f_{\rm bound}, \sigma_v, r_h)\).
Is this relationship linear? Almost certainly not. The physics of gravitational collapse, violent relaxation, and stellar escape involves complex nonlinear dynamics. You could try to engineer features based on virial theorem scaling relations — but why guess when you can learn?
A neural network will discover whatever nonlinear transformations of \((Q_0, a)\) are useful for predicting cluster outcomes. If it turns out linear features suffice, the network can learn that too (it includes linear models as a special case).
1.3: The Computational Neuron
Priority: Essential
Anatomy of a Neuron
Artificial neuron: A computational unit that computes a weighted sum of inputs, adds a bias, and applies a nonlinear activation function. Inspired by (but vastly simpler than) biological neurons.
The artificial neuron is the atomic unit of neural networks. It takes \(d\) inputs \(x_1, x_2, \ldots, x_d\) and produces one output:
\[
a = \sigma\left(\sum_{i=1}^{d} w_i x_i + b\right) = \sigma(\mathbf{w}^T \mathbf{x} + b),
\]
where:
- \(\mathbf{w} = (w_1, \ldots, w_d)\) are weights controlling how much each input matters
- \(b\) is a bias allowing the neuron to shift its activation threshold
- \(\sigma\) is a nonlinear activation function
Think of it as “how strongly should this neuron fire, given these inputs?” The weights determine sensitivity to each input; the bias sets the baseline; the activation function introduces nonlinearity.
Pre-activation: The quantity \(z = \mathbf{w}^T \mathbf{x} + b\) before the nonlinearity acts. Represents “how much this neuron wants to activate” before thresholding.
The quantity \(z = \mathbf{w}^T \mathbf{x} + b\) is called the pre-activation — the neuron’s response before the nonlinear squashing. The activation function then decides the actual output: \(a = \sigma(z)\).
Note: Neuron Computation Map
Input features \(x_1, x_2, \ldots, x_d\) are multiplied by weights \(w_1, w_2, \ldots, w_d\), summed with a bias \(b\), and passed through an activation: \[
z = \sum_i w_i x_i + b, \qquad a = \sigma(z).
\]
The important split is linear pre-activation followed by nonlinear activation.
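A single neuron is only a few lines of JAX. This minimal sketch uses ReLU as the default activation; the weights, bias, and inputs are made-up numbers chosen so the pre-activation is easy to check by hand.

```python
import jax.numpy as jnp
from jax.nn import relu


def neuron(w, b, x, activation=relu):
    """One artificial neuron: weighted sum plus bias, then a nonlinearity."""
    z = jnp.dot(w, x) + b  # pre-activation
    return activation(z)   # activation a = sigma(z)


w = jnp.array([0.8, -0.5])
b = 0.1
print(neuron(w, b, jnp.array([1.0, 2.0])))  # z = 0.8 - 1.0 + 0.1 = -0.1, so ReLU gives 0
print(neuron(w, b, jnp.array([2.0, 1.0])))  # z = 1.6 - 0.5 + 0.1 = 1.2, passed through
```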
Why Nonlinearity Is Essential
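Important: The Collapse of Linear Networks
Without activation functions, a neuron computes a linear (strictly, affine) transformation: \(a = \mathbf{w}^T \mathbf{x} + b\). Stacking two such layers gives another transformation of the same form:
\[
\mathbf{W}_2(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2 = (\mathbf{W}_2 \mathbf{W}_1)\mathbf{x} + (\mathbf{W}_2 \mathbf{b}_1 + \mathbf{b}_2).
\]
No matter how many layers you stack, the network collapses to a single linear transformation. Deep networks would be no more powerful than shallow ones!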
The nonlinearity \(\sigma\) prevents this collapse. It’s what allows networks to represent complex, nonlinear relationships — like how a cluster’s bound fraction depends on initial virial ratio through the physics of violent relaxation and escape.
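The collapse of stacked linear layers can be verified numerically. This sketch draws random weights for two activation-free layers (sizes 2 to 8 to 3, an arbitrary choice) and checks that they match a single layer with merged weight \(\mathbf{W}_2\mathbf{W}_1\) and merged bias \(\mathbf{W}_2\mathbf{b}_1 + \mathbf{b}_2\).

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4, kx = jax.random.split(jax.random.PRNGKey(1), 5)
W1, b1 = jax.random.normal(k1, (8, 2)), jax.random.normal(k2, (8,))
W2, b2 = jax.random.normal(k3, (3, 8)), jax.random.normal(k4, (3,))
x = jax.random.normal(kx, (2,))

# Two "layers" with no activation in between ...
two_layer = W2 @ (W1 @ x + b1) + b2

# ... are exactly one layer with merged weights and bias.
W_merged = W2 @ W1
b_merged = W2 @ b1 + b2
one_layer = W_merged @ x + b_merged

print(two_layer, one_layer)
```

Inserting `jax.nn.relu` between the layers breaks this identity in general, which is exactly why the activation is there.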
Activation Functions
Several activation functions are commonly used:
ReLU (Rectified Linear Unit) — the modern default: \[\text{ReLU}(z) = \max(0, z) = \begin{cases} z & \text{if } z > 0 \\ 0 & \text{if } z \leq 0 \end{cases}\]
Beautifully simple: pass positive values unchanged, zero out negative values. Computationally cheap and works remarkably well. The gradient is 1 for \(z > 0\) and 0 for \(z < 0\), which helps training (no “vanishing gradients” for positive inputs).
Sigmoid (the “Fermi function” you know from statistical mechanics): \[\sigma(z) = \frac{1}{1 + e^{-z}}\]
Squashes outputs to \((0, 1)\). Useful for probabilities, but can cause vanishing gradients in deep networks (derivative is small when \(|z|\) is large).
Tanh (hyperbolic tangent): \[\tanh(z) = \frac{e^{z} - e^{-z}}{e^{z} + e^{-z}}\]
Squashes outputs to \((-1, 1)\). Zero-centered, which can help optimization. Same vanishing gradient concern as sigmoid.
Note: Activation Function Comparison

| Activation | Range | Useful intuition | Main caution |
|---|---|---|---|
| ReLU | \([0, \infty)\) | Cheap, sparse, strong default for hidden layers | Zero gradient for inactive units |
| Sigmoid | \((0, 1)\) | Probability-like output | Saturates for large \(\lvert z \rvert\) |
| Tanh | \((-1, 1)\) | Zero-centered sigmoid-like shape | Also saturates for large \(\lvert z \rvert\) |
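The "main caution" column is about derivatives, which `jax.grad` can evaluate directly. This sketch prints \(d\sigma/dz\) for each activation at a few sample points. The points are arbitrary; \(z = 0\) is avoided because ReLU is not differentiable there.

```python
import jax
import jax.numpy as jnp

activations = {
    "relu": jax.nn.relu,
    "sigmoid": jax.nn.sigmoid,
    "tanh": jnp.tanh,
}

z = jnp.array([-5.0, -1.0, 0.5, 5.0])
for name, fn in activations.items():
    # jax.grad needs a scalar-to-scalar function, so vmap maps the derivative over z.
    dfdz = jax.vmap(jax.grad(fn))(z)
    print(f"{name:>7}: {dfdz}")
```

At \(z = \pm 5\) the sigmoid and tanh derivatives are already below 0.01, which is the saturation that slows learning in deep stacks; ReLU keeps a derivative of exactly 1 for every positive input.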
Historical note: Early neural networks (1980s–90s) used sigmoid and tanh. ReLU was popularized around 2010 and revolutionized deep learning by alleviating the vanishing gradient problem. Sometimes the simplest solutions win.
For your emulator: Use ReLU for hidden layers. It is the standard default for hidden layers in regression networks like this one and works well in practice.
1.4: Layering Neurons into Networks
Priority: Essential
The Multi-Layer Perceptron
Multi-Layer Perceptron (MLP): A neural network architecture consisting of an input layer, one or more hidden layers, and an output layer. Also called a “feed-forward” network because information flows in one direction (no loops).
A single neuron has limited expressiveness. The power comes from layering neurons into a network:
Input layer: Your raw features \(\mathbf{x} \in \mathbb{R}^{d_{\text{in}}}\). Not really “neurons” — just the data entering the network.
Hidden layer(s): Where the magic happens. Each layer transforms its input through weights, biases, and activations. A hidden layer with \(n\) neurons computes:
\[
\mathbf{h} = \sigma(\mathbf{W}\mathbf{x} + \mathbf{b}),
\]
where \(\mathbf{W} \in \mathbb{R}^{n \times d_{\text{in}}}\) is the weight matrix, \(\mathbf{b} \in \mathbb{R}^n\) is the bias vector, and \(\sigma\) applies element-wise.
Output layer: Produces final predictions \(\hat{\mathbf{y}} \in \mathbb{R}^{d_{\text{out}}}\). For regression, typically no activation (linear output).
Your Emulator Architecture
For the final project, you’ll build:
- Input: 2 features \((Q_0, a)\)
- Hidden 1: 64 neurons, ReLU activation
- Hidden 2: 64 neurons, ReLU activation
- Output: 3 predictions \((f_{\text{bound}}, \sigma_v, r_h)\), linear (no activation)
Units reminder: Ensure \(\sigma_v\) uses units consistent with your JAX simulator (typically km/s or your chosen internal N-body units). The scale radius \(a\) and half-mass radius \(r_h\) should use the same length convention throughout the pipeline.
```mermaid
flowchart LR
subgraph Input["Input Layer (2)"]
Q0["Q0"]
a["a"]
end
subgraph H1["Hidden Layer 1\n64 neurons, ReLU"]
h1["h1 layer 1"]
h2["..."]
h3["h64 layer 1"]
end
subgraph H2["Hidden Layer 2\n64 neurons, ReLU"]
h4["h1 layer 2"]
h5["..."]
h6["h64 layer 2"]
end
subgraph Output["Output Layer (3)\nLinear"]
f["f_bound"]
s["sigma_v"]
r["r_h"]
end
Q0 --> H1
a --> H1
H1 --> H2
H2 --> Output
```
Counting Parameters
Learnable parameters: The weights and biases that training optimizes. The “knobs” the network adjusts to fit data.
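How many learnable parameters does your network have? This matters for understanding model complexity. Each dense layer contributes \(n_{\rm in} \times n_{\rm out}\) weights plus \(n_{\rm out}\) biases:
\[
\underbrace{(2 \times 64 + 64)}_{192} + \underbrace{(64 \times 64 + 64)}_{4160} + \underbrace{(64 \times 3 + 3)}_{195} = 4547.
\]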
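The same bookkeeping fits in a small helper, assuming fully connected layers with one bias per output unit. The name `count_mlp_parameters` is hypothetical.

```python
def count_mlp_parameters(sizes):
    """Weights (n_in * n_out) plus biases (n_out) for each consecutive pair of layer sizes."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


print(count_mlp_parameters([2, 64, 64, 3]))  # 4547 for the emulator architecture
```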
Warning: Connection to Module 1: Degrees of Freedom and Overfitting
With 4,547 parameters and only ~100 training examples, alarm bells should ring. You have ~45\(\times\) more parameters than data points!
In Module 1, you learned that model complexity must balance against data availability. Too many degrees of freedom and you fit noise, not signal — the model memorizes training data but fails on new inputs.
Why doesn’t this automatically doom neural networks? Several factors help, but none of them remove the need for validation:
Implicit regularization: Gradient descent with early stopping tends to find “simple” solutions
Low-dimensional input domain: The map from \((Q_0, a)\) to three summaries is much smaller than an image-recognition problem
Expected smoothness: Nearby initial conditions often produce nearby summary statistics, except near real dynamical transitions
Architecture restraint: This MLP is small relative to modern deep networks
Validation: We monitor validation performance to detect overfitting while keeping the test set untouched
Normalization and early stopping: These keep optimization better behaved
Still, with limited data, you should expect some overfitting risk. The emulator is scientifically valid only inside the parameter domain covered by the training design, and only where held-out residuals and calibration checks support it.
1.5: The Universal Approximation Theorem
Priority: Important
The Theoretical Foundation
Universal Approximation Theorem: First proved by Cybenko (1989) for sigmoid activations; extended to broader classes of activations, including ReLU, by Hornik, Leshno, and others.
Here’s a remarkable mathematical result that justifies using neural networks as flexible function approximators:
Theorem (Universal Approximation): Let \(f: K \to \mathbb{R}\) be any continuous function defined on a compact set \(K \subset \mathbb{R}^d\). For any \(\epsilon > 0\), there exists a neural network \(\hat{f}\) with a single hidden layer and sufficiently many neurons such that:
\[
\sup_{\mathbf{x} \in K} \lvert f(\mathbf{x}) - \hat{f}(\mathbf{x}) \rvert < \epsilon.
\]
In words: any continuous function on a compact domain can be approximated arbitrarily well, everywhere on that domain at once, by a neural network with one hidden layer, if you use enough neurons.
What this means for emulation: If the summary statistics vary continuously with the initial conditions (an assumption you should check; for example, \(f_{\rm bound}\) counts a finite number of bound particles, so it changes in discrete steps), then your N-body simulator defines a continuous function from initial conditions to summary statistics, and the theorem guarantees that some neural network can approximate this mapping on the sampled domain to any desired accuracy.
The Intuition: Building Functions from Simple Pieces
How can a network approximate arbitrary functions?
For sigmoid activations: Each neuron creates a “soft step” — a smooth transition from low to high output. By positioning many soft steps at different locations with different heights, you can approximate any shape. It’s like building a sculpture from many small clay pieces, each contributing a local bump or dip.
For ReLU activations: Each neuron creates a “hinge” — a bend in the function at some location. A network with \(n\) ReLU neurons can create a piecewise linear function with up to \(n\) hinges. With enough hinges, piecewise linear functions can approximate any continuous curve.
Note: ReLU Approximation Intuition
Picture a smooth target curve, then approximate it with more and more straight-line segments. A few ReLU neurons make a crude piecewise-linear sketch; more neurons add more hinges; enough well-placed hinges can track the curve closely over the training domain.
The phrase “over the training domain” matters. This theorem does not promise sensible extrapolation outside the sampled range.
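The hinge picture can be made concrete without any training. This sketch builds a one-hidden-layer ReLU network by hand: hidden unit \(k\) computes \(\text{ReLU}(x - t_k)\) for a knot \(t_k\), and its output weight is the change in slope at that knot, so the network reproduces linear interpolation of the target between the knots. The target \(\sin x\) on \([0, \pi]\) and the helper name `relu_interpolant` are illustrative assumptions; a trained network would place its hinges by gradient descent instead.

```python
import jax.numpy as jnp
from jax.nn import relu


def relu_interpolant(f, lo, hi, n_units):
    """Hand-built 1-hidden-layer ReLU network that linearly interpolates f on [lo, hi]."""
    knots = jnp.linspace(lo, hi, n_units + 1)
    values = f(knots)
    slopes = jnp.diff(values) / jnp.diff(knots)  # slope on each of the n_units segments
    # Output weights: first slope, then the change in slope at each later knot.
    slope_changes = jnp.concatenate([slopes[:1], jnp.diff(slopes)])

    def f_hat(x):
        hidden = relu(x[:, None] - knots[:-1])     # (len(x), n_units) hinge activations
        return values[0] + hidden @ slope_changes  # linear output layer
    return f_hat


x_test = jnp.linspace(0.0, jnp.pi, 1001)
for n in (4, 8, 32):
    approx = relu_interpolant(jnp.sin, 0.0, jnp.pi, n)
    max_err = jnp.max(jnp.abs(approx(x_test) - jnp.sin(x_test)))
    print(f"{n:3d} ReLU units: max error {float(max_err):.4f}")
```

The maximum error shrinks as units are added. Outside \([0, \pi]\) this network is flat on the left and continues its last segment on the right, neither of which has anything to do with \(\sin x\).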
The Crucial Caveats
Warning: What the Theorem Does NOT Guarantee
1. How many neurons you need: The theorem says “sufficiently many” but the number could be astronomically large for complex functions. It’s an existence theorem, not a construction.
2. How to find the right weights: Knowing a good network exists doesn’t tell you how to train it. Optimization might get stuck in poor local minima.
3. That training data suffices: The theorem assumes you can evaluate the true function anywhere. In practice, you only have finite samples. Generalization from samples to the full function is a separate challenge.
4. Behavior outside the training region: The theorem applies on a compact set \(K\). Outside this region, the network extrapolates with no guarantees.
Universal approximation \(\neq\) universal learning. The theorem speaks to expressiveness (what networks can represent), not learnability (what we can find from data).
For the final project, the theorem only tells us that a neural network could approximate the simulator summary map on the sampled \((Q_0, a)\) domain. It does not tell us that your training set is dense enough, that Adam found the right approximation, that the emulator respects cluster dynamics outside the training range, or that the resulting posterior will be calibrated.
1.6: Forward Propagation
Priority: Essential
Computing Predictions
Forward propagation (or forward pass): The process of computing a network’s output from its input by applying each layer’s transformation sequentially.
Forward propagation is how we compute predictions: apply each layer’s transformation in sequence, feeding outputs forward.
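For your 2-hidden-layer network, the forward pass computes:
\[
\mathbf{h}_1 = \text{ReLU}(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1), \qquad
\mathbf{h}_2 = \text{ReLU}(\mathbf{W}_2 \mathbf{h}_1 + \mathbf{b}_2), \qquad
\hat{\mathbf{y}} = \mathbf{W}_3 \mathbf{h}_2 + \mathbf{b}_3.
\]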
Each layer takes the previous layer’s output, applies a linear transformation (matrix multiply + bias), and passes through an activation (except the final layer).
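The layer-by-layer forward pass translates almost line for line into JAX. In this minimal sketch, `init_mlp` and `forward` are hypothetical helper names, parameters are stored as a list of `(W, b)` pairs with `W` shaped `(n_out, n_in)`, and the He-style random scaling is one common initialization choice assumed here (Section 1.8 discusses initialization properly). `jax.vmap` maps the single-input forward pass over a batch.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes=(2, 64, 64, 3)):
    """Hypothetical helper: list of (W, b) pairs, W of shape (n_out, n_in)."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        W = jax.random.normal(k, (n_out, n_in)) * jnp.sqrt(2.0 / n_in)  # He-style scale (assumed)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params


def forward(params, x):
    """Forward pass for a single input x of shape (2,)."""
    h = x
    for W, b in params[:-1]:
        h = jax.nn.relu(W @ h + b)  # hidden layers: affine map, then ReLU
    W_out, b_out = params[-1]
    return W_out @ h + b_out        # output layer: linear, no activation


batched_forward = jax.vmap(forward, in_axes=(None, 0))  # share params, map over inputs

params = init_mlp(jax.random.PRNGKey(0))
X_batch = jax.random.uniform(jax.random.PRNGKey(1), (5, 2))  # placeholder inputs
out = batched_forward(params, X_batch)
print(out.shape)
```

Summing `W.size + b.size` over the layers recovers the 4,547 parameters counted in Section 1.4.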
Note: Why No Activation on the Output Layer?
For regression tasks (predicting continuous values like \(f_{\text{bound}}\), \(\sigma_v\), \(r_h\)), the output layer should be linear — no activation function.
Activations like sigmoid squash outputs to bounded ranges such as \((0, 1)\). If your targets can span ranges like \(r_h \in [20, 150]\) AU, a sigmoid would artificially constrain predictions. ReLU would prevent negative predictions. That might look harmless for \(f_{\rm bound} \geq 0\), but the network is trained on normalized targets, and a standardized target is negative whenever its value lies below the training mean, so a ReLU output layer could not fit those examples.
For classification tasks, you’d use sigmoid (binary) or softmax (multiclass) to produce probabilities. But your emulator does regression, so keep the output linear.
The Computational Graph Perspective
Computational graph: A directed acyclic graph representing how computations compose. Nodes are operations or values; edges show data flow. JAX builds these automatically when tracing functions.
Forward propagation defines a computational graph — exactly the structure JAX uses for automatic differentiation.
This graph shows exactly how the output \(\hat{\mathbf{y}}\) depends on the input \(\mathbf{x}\) and all parameters \(\boldsymbol{\theta} = (\mathbf{W}_1, \mathbf{b}_1, \mathbf{W}_2, \mathbf{b}_2, \mathbf{W}_3, \mathbf{b}_3)\).
When we train, we need gradients \(\nabla_{\boldsymbol{\theta}} \mathcal{L}\). The computational graph tells JAX exactly how to compute them via the chain rule — that’s what jax.grad does automatically.
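You can ask JAX to show the computational graph it traces. This sketch applies `jax.make_jaxpr` to a tiny one-hidden-layer function with all-ones weights (an arbitrary choice for display only). The printed jaxpr lists each primitive operation, such as the matrix products (`dot_general`), the additions, and the elementwise maximum that implements ReLU, in the order the forward pass executes them.

```python
import jax
import jax.numpy as jnp


def tiny_forward(W1, b1, W2, b2, x):
    h = jnp.maximum(W1 @ x + b1, 0.0)  # ReLU written out explicitly
    return W2 @ h + b2                 # linear output layer


W1 = jnp.ones((4, 2))
b1 = jnp.zeros(4)
W2 = jnp.ones((3, 4))
b2 = jnp.zeros(3)
x = jnp.array([0.5, 1.0])

graph = jax.make_jaxpr(tiny_forward)(W1, b1, W2, b2, x)
print(graph)  # the traced computational graph, one primitive per line
```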
Conceptual Checkpoint: Architecture Foundations
Before moving to training, verify your understanding:
Important: Self-Assessment
Nonlinearity: If you removed all ReLU activations from your network, what class of functions could it represent? Why is this limiting?
Parameter count: A network has layers of sizes 10 \(\to\) 50 \(\to\) 50 \(\to\) 5. How many total parameters (weights + biases)?
Forward pass: Given input \(\mathbf{x} = [0.5, 1.0]^T\) and first layer weights \(\mathbf{W}_1 = \begin{pmatrix} 1 & -1 \\ 0 & 2 \end{pmatrix}\) with bias \(\mathbf{b}_1 = [0, -1]^T\), compute \(\mathbf{h}_1\) (the first hidden layer output, with ReLU).
Architecture choice: Why does your emulator use ReLU for hidden layers but no activation for the output layer?
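Tip: Check Your Answers
1. Only linear functions \(\hat{\mathbf{y}} = \mathbf{A}\mathbf{x} + \mathbf{c}\). Limiting because most physical relationships are nonlinear.
2. \((10 \times 50 + 50) + (50 \times 50 + 50) + (50 \times 5 + 5) = 550 + 2550 + 255 = 3355\) parameters.
3. \(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1 = [0.5 - 1.0 + 0,\; 0 + 2.0 - 1]^T = [-0.5, 1.0]^T\), so \(\mathbf{h}_1 = \text{ReLU}([-0.5, 1.0]^T) = [0, 1.0]^T\).
4. ReLU provides nonlinearity so hidden layers can learn complex features. The output layer is linear because we're doing regression, so predictions should be unconstrained real numbers.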
1.7: Training as Optimization
Priority: Essential
The Loss Function
Loss function (or cost function, objective function): A scalar function measuring how wrong the network’s predictions are. Training minimizes this function.
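Training requires a measure of “how wrong” the network is. For regression, the standard choice is mean squared error (MSE):
\[
\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{N} \sum_{i=1}^{N} \lVert \mathbf{y}_i - \hat{\mathbf{y}}_i \rVert^2,
\]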
where \(\boldsymbol{\theta}\) denotes all network parameters, \(N\) is the number of training examples, and \(\hat{\mathbf{y}}_i = f_{\text{network}}(\mathbf{x}_i; \boldsymbol{\theta})\) is the prediction for input \(\mathbf{x}_i\).
The loss is a function of the parameters: different weights produce different predictions, which produce different errors. Training finds parameters that minimize the loss.
These three quantities are related, but they are not interchangeable:
| Quantity | Used for | Example |
|---|---|---|
| Training loss | fitting network weights | MSE on normalized outputs |
| Validation metric | detecting generalization failure | RMSE on held-out simulations |
| Scientific likelihood | inference over physical parameters | Gaussian likelihood in physical output units |
In Bayesian inference, you maximize or sample from a posterior. In neural-network training, you minimize a loss. These are connected when the loss can be derived from a likelihood, but the interpretation depends on the assumptions.
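For a multi-output emulator, a simple diagonal Gaussian residual model says
\[
p(\mathbf{y}_i \mid \mathbf{x}_i, \boldsymbol{\theta}) = \prod_{k=1}^{3} \mathcal{N}\!\left(y_{ik} \mid \hat{y}_{ik}, \sigma_k^2\right),
\qquad
-\log p(\mathbf{y}_i \mid \mathbf{x}_i, \boldsymbol{\theta}) = \sum_{k=1}^{3} \left[ \frac{(y_{ik} - \hat{y}_{ik})^2}{2\sigma_k^2} + \log \sigma_k \right] + \text{const}.
\]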
If all outputs are normalized and given equal unit variance for training, minimizing MSE is equivalent to maximizing this Gaussian likelihood under a fixed equal-variance residual model in normalized space. That is useful for optimization. It is not automatically the same as the physical likelihood you will use in NumPyro.
The physical likelihood must compare physical-unit predictions to physical-unit observations, with uncertainties in the same units as the outputs.
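A minimal sketch of the physical-unit comparison, assuming the emulator outputs normalized predictions and that you kept the training-set output mean and standard deviation. The function name `physical_log_likelihood` and all numbers below are hypothetical placeholders; only the order of operations matters: undo the normalization, then evaluate independent Gaussians with \(\sigma\) in the same units as the observations. It uses `jax.scipy.stats.norm.logpdf`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def physical_log_likelihood(y_obs, y_pred_norm, y_mean, y_std, sigma_obs):
    """Diagonal-Gaussian log-likelihood evaluated in PHYSICAL units.

    y_pred_norm: emulator output in normalized (training) space.
    y_mean, y_std: training-set output statistics used for normalization.
    y_obs, sigma_obs: observed summaries and uncertainties, in the same physical units.
    """
    y_pred = y_pred_norm * y_std + y_mean  # undo output normalization first
    return jnp.sum(norm.logpdf(y_obs, loc=y_pred, scale=sigma_obs))


# Made-up placeholder values for (f_bound, sigma_v, r_h), not real simulation output.
y_mean = jnp.array([0.7, 2.0, 80.0])
y_std = jnp.array([0.1, 0.5, 30.0])
y_obs = jnp.array([0.65, 2.2, 70.0])
sigma_obs = jnp.array([0.05, 0.3, 10.0])

y_pred_norm = jnp.array([-0.3, 0.5, -0.2])  # pretend emulator output
print(physical_log_likelihood(y_obs, y_pred_norm, y_mean, y_std, sigma_obs))
```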
Gradient Descent
Gradient descent: An iterative optimization algorithm that updates parameters in the direction opposite to the gradient (direction of steepest increase), thereby decreasing the objective function.
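We want \(\boldsymbol{\theta}^* = \arg\min_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta})\). Gradient descent iteratively steps toward the minimum:
\[
\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \, \nabla_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta}_t),
\]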
where \(\eta > 0\) is the learning rate — how large a step we take.
Learning rate (\(\eta\)): The step size in gradient descent. Controls the tradeoff between speed (large \(\eta\)) and stability (small \(\eta\)). Choosing it well is crucial.
The gradient \(\nabla_{\boldsymbol{\theta}} \mathcal{L}\) points in the direction of steepest increase in loss. We move in the opposite direction to decrease the loss.
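Here is the update rule run end to end with `jax.grad`, on the simplest possible model: a linear fit \(\hat{y} = \mathbf{x}^T\mathbf{w} + b\) to noise-free synthetic data. The learning rate \(\eta = 0.1\), the 500 steps, and the data are illustrative assumptions. Inputs are drawn centered on zero, mimicking normalized features, which keeps this problem well conditioned.

```python
import jax
import jax.numpy as jnp


def mse_loss(theta, x, y):
    """MSE of a linear model y_hat = x @ w + b, with theta = (w, b)."""
    w, b = theta
    y_hat = x @ w + b
    return jnp.mean((y - y_hat) ** 2)


grad_fn = jax.jit(jax.grad(mse_loss))  # gradient with respect to theta (first argument)


def gradient_descent(theta, x, y, eta=0.1, n_steps=500):
    for _ in range(n_steps):
        grads = grad_fn(theta, x, y)
        # theta <- theta - eta * grad, applied leaf by leaf to the (w, b) tuple
        theta = jax.tree_util.tree_map(lambda p, g: p - eta * g, theta, grads)
    return theta


# Synthetic, noise-free data with centered inputs (stand-ins for normalized features).
x = jax.random.uniform(jax.random.PRNGKey(0), (100, 2), minval=-1.0, maxval=1.0)
y = x @ jnp.array([1.5, -2.0]) + 0.5

theta0 = (jnp.zeros(2), jnp.array(0.0))
w_fit, b_fit = gradient_descent(theta0, x, y)
print(w_fit, b_fit)
```

The same loop trains an MLP if `mse_loss` calls a network instead of a linear model; in practice an optimizer library such as Optax (Section 1.11) replaces the hand-written update.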
Tip: Predict – Try – Explain: Learning Rate
Predict: If the learning rate is too large, what will the loss curve look like?
Try: Sketch three loss curves: too large, too small, and reasonable. Label what you would change in each case.
Explain: The learning rate is not just a speed knob. It controls whether the optimizer can settle into a useful minimum at all.
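For the "Try" step, a one-parameter loss \(\mathcal{L}(\theta) = \theta^2/2\) makes the three regimes easy to see, because each update multiplies \(\theta\) by \(1 - \eta\). This is a toy sketch; the specific learning rates are chosen to land in each regime.

```python
import jax


def loss(theta):
    return 0.5 * theta ** 2  # curvature 1, minimum at theta = 0


grad_loss = jax.grad(loss)


def loss_curve(eta, theta0=1.0, n_steps=20):
    """Run plain gradient descent and record the loss after every step."""
    theta = theta0
    history = [float(loss(theta))]
    for _ in range(n_steps):
        theta = theta - eta * grad_loss(theta)
        history.append(float(loss(theta)))
    return history


for label, eta in [("too large", 2.5), ("too small", 0.01), ("reasonable", 0.5)]:
    curve = loss_curve(eta)
    print(f"{label:>10} (eta={eta}): start {curve[0]:.3g}, end {curve[-1]:.3g}")
```

With \(\eta = 2.5\) the factor is \(-1.5\), so the loss grows every step; with \(\eta = 0.01\) it shrinks by only about 2% per step; with \(\eta = 0.5\) it drops by a factor of 4 per step.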
Note: Visual Companion: Gradient Descent
3Blue1Brown’s “Gradient descent, how neural networks learn” is a good visual bridge from this equation to training behavior. The digit-recognition example has many more inputs than our emulator, but the optimization story is the same: define a loss, compute which direction increases it, then step the parameters in the opposite direction.
For this course, keep one extra scientific habit in mind: decreasing training loss is not enough. You still need validation loss, held-out evaluation, and a baseline comparison before trusting the emulator.
Credit: 3Blue1Brown / Grant Sanderson
NoteConnection to Module 5: Optimization vs. Sampling
You’ve now seen three approaches to finding good parameters:
| Approach | What It Finds | Module |
|---|---|---|
| Analytical | Exact optimum (when possible) | Linear regression |
| Sampling (MCMC) | Full posterior distribution | Module 5, Project 4 |
| Optimization (GD) | Single point estimate (mode) | This reading |
Gradient descent finds one answer — the parameter values at the (local) minimum. MCMC explores the full posterior — the entire distribution of plausible parameters.
For inference, we often want the full posterior. For training emulators, we usually just want a good point estimate. But we’ll recover uncertainty via ensembles (Section 1.10).
Backpropagation = Chain Rule = Autodiff
Backpropagation: The algorithm for computing gradients of the loss with respect to all network parameters. Mathematically, it’s the chain rule applied systematically through the computational graph.
The key insight is that backpropagation is not a separate learning rule. It is the chain rule, applied layer by layer from the loss back to the parameters, reusing intermediate results from the forward pass.
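Consider how the loss depends on the first-layer weights \(\mathbf{W}^{(1)}\) of a network with hidden layers \(\mathbf{h}^{(1)}\), \(\mathbf{h}^{(2)}\) and output \(\hat{\mathbf{y}}\):
\[
\frac{\partial \mathcal{L}}{\partial \mathbf{W}^{(1)}}
= \frac{\partial \mathcal{L}}{\partial \hat{\mathbf{y}}}
\, \frac{\partial \hat{\mathbf{y}}}{\partial \mathbf{h}^{(2)}}
\, \frac{\partial \mathbf{h}^{(2)}}{\partial \mathbf{h}^{(1)}}
\, \frac{\partial \mathbf{h}^{(1)}}{\partial \mathbf{W}^{(1)}}
\]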
The chain rule propagates gradients backward from the loss through each layer to the parameters.
NoteVisual Companion: What Backpropagation Is Doing
3Blue1Brown’s “What is backpropagation really doing?” is the best place to build intuition before staring at code. Watch for the central idea: each weight and bias gets a sensitivity score telling us how changing it would change the loss.
In your emulator, those sensitivities answer a concrete question: how should the weights of the MLP change so that predictions of \(f_{\rm bound}\), \(\sigma_v\), and \(r_h\) better match the simulator outputs?
Credit: 3Blue1Brown / Grant Sanderson
TipConnection to the JAX Module: You Already Know This!
In the JAX module, you learned that JAX computes gradients via automatic differentiation — specifically, reverse-mode autodiff. Backpropagation is exactly this: reverse-mode autodiff applied to neural networks.
JAX is performing backpropagation automatically. It traces the forward pass to build the computational graph, then applies the chain rule in reverse during grad().
You don’t implement backprop by hand. JAX does it for you. But understanding what it computes helps you debug and reason about training dynamics.
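To check that `jax.grad` is doing exactly this chain-rule bookkeeping, the sketch below writes the derivative out by hand for a tiny one-hidden-layer network and compares it with JAX. The network size, the tanh activation, and the input values are illustrative assumptions, not the emulator.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    z = params["w1"] * x + params["b1"]          # hidden pre-activations, shape (4,)
    h = jnp.tanh(z)                               # hidden activations
    y = jnp.dot(params["w2"], h) + params["b2"]   # scalar output
    return y, h

def loss_fn(params, x, target):
    y, _ = forward(params, x)
    return (y - target) ** 2

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": jax.random.normal(key1, (4,)),
    "b1": jnp.zeros(4),
    "w2": jax.random.normal(key2, (4,)),
    "b2": jnp.array(0.0),
}
x, target = 0.7, 1.5

# Chain rule by hand: dL/dw1 = dL/dy * dy/dh * dh/dz * dz/dw1
y, h = forward(params, x)
dL_dy = 2.0 * (y - target)
dL_dw1_manual = dL_dy * params["w2"] * (1.0 - h ** 2) * x

# Reverse-mode autodiff (backpropagation) done by JAX
dL_dw1_jax = jax.grad(loss_fn)(params, x, target)["w1"]
```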
NoteOptional Depth: Backpropagation Calculus and JAX
3Blue1Brown’s “Backpropagation calculus” shows the chain-rule bookkeeping more explicitly. This is optional depth, but it is exactly the math that JAX automates when you call jax.grad or eqx.filter_value_and_grad.
The important connection is this: JAX makes backpropagation practical because it records the computational graph of array operations and applies reverse-mode automatic differentiation to that graph. You write the forward model and scalar loss; JAX constructs the derivative computation needed by gradient descent.
Credit: 3Blue1Brown / Grant Sanderson
The Adam Optimizer
Adam (Adaptive Moment Estimation): An optimizer that adapts learning rates for each parameter based on gradient history. Tracks running averages of gradients (first moment) and squared gradients (second moment).
Plain gradient descent uses a fixed learning rate for all parameters. This can be suboptimal — some parameters benefit from larger steps, others from smaller.
Adam adapts the learning rate per parameter based on gradient history:
First moment (exponential moving average of gradients): Which direction have we been moving?
Second moment (exponential moving average of squared gradients): How variable have gradients been?
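Parameters whose gradients point consistently in one direction take steps of roughly size \(\eta\), whatever the raw gradient magnitude, so small but consistent gradients are effectively boosted. Parameters with noisy, sign-flipping gradients are damped: their first moment averages toward zero while their second moment stays large.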
In practice: Adam is the default optimizer for neural networks. It works well across problems with minimal tuning. Start with learning rate \(\eta = 10^{-3}\).
Optax provides Adam and many other optimizers.
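Here is a minimal sketch of the Adam update written out in JAX. It uses the commonly quoted defaults \(\beta_1 = 0.9\), \(\beta_2 = 0.999\), \(\epsilon = 10^{-8}\), which are also the defaults of `optax.adam`. The badly scaled quadratic is a toy assumption chosen to show the per-parameter step size. Plain gradient descent must keep its learning rate below the stability limit of the stiff direction, and in this example that leaves it crawling along the flat one.

```python
import jax
import jax.numpy as jnp

def loss(theta):
    # Badly scaled quadratic: curvature 100 along theta[0], 0.01 along theta[1]
    return 0.5 * (100.0 * theta[0] ** 2 + 0.01 * theta[1] ** 2)

grad_loss = jax.jit(jax.grad(loss))

def adam(theta, lr=0.05, b1=0.9, b2=0.999, eps=1e-8, num_steps=1000):
    m = jnp.zeros_like(theta)  # first moment: running mean of gradients
    v = jnp.zeros_like(theta)  # second moment: running mean of squared gradients
    for t in range(1, num_steps + 1):
        g = grad_loss(theta)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g ** 2
        m_hat = m / (1 - b1 ** t)  # bias corrections for the zero initialization
        v_hat = v / (1 - b2 ** t)
        theta = theta - lr * m_hat / (jnp.sqrt(v_hat) + eps)  # per-parameter step
    return theta

def plain_gd(theta, lr=0.015, num_steps=1000):
    # lr must stay below 2/100 = 0.02, or the stiff direction diverges
    for _ in range(num_steps):
        theta = theta - lr * grad_loss(theta)
    return theta

theta0 = jnp.array([1.0, 1.0])
theta_adam = adam(theta0)
theta_gd = plain_gd(theta0)
```

After 1000 steps, plain gradient descent has shrunk `theta[1]` only by a factor of about \(0.99985^{1000} \approx 0.86\). Adam's normalization by \(\sqrt{\hat{v}}\) lets the flat direction move at a comparable rate.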
1.8: Weight Initialization — Breaking Symmetry
Priority: Important
Why Initialization Matters
Before training begins, you must set initial values for all weights and biases. This choice profoundly affects training dynamics.
WarningThe Symmetry Problem
What if you initialize all weights to zero?
Every neuron in a layer computes the same function (weighted sum of inputs with weights 0 \(\to\) output 0). They all produce identical outputs. During backprop, they all receive identical gradients. They all update identically. They stay identical forever.
You’ve effectively reduced your 64-neuron layer to a single neuron. The network can’t learn complex features because all neurons are copies of each other.
Random initialization breaks this symmetry. Different neurons start with different weights, compute different functions, receive different gradients, and specialize to detect different features.
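The sketch below makes the symmetry argument concrete for a hypothetical 2-8-1 tanh network. It uses a nonzero constant (0.3) so the gradients are not trivially zero. With all-zero weights, the hidden activations are \(\tanh(0) = 0\), so both weight matrices receive exactly zero gradient. Each row of the `W1` gradient is the update for one hidden neuron's incoming weights.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    h = jnp.tanh(params["W1"] @ x + params["b1"])  # hidden layer, 8 neurons
    return params["W2"] @ h + params["b2"]          # output, shape (1,)

def loss_fn(params, x, y):
    return jnp.sum((mlp(params, x) - y) ** 2)

def make_params(W1, W2):
    return {"W1": W1, "b1": jnp.zeros(8), "W2": W2, "b2": jnp.zeros(1)}

x = jnp.array([0.5, -1.2])
y = jnp.array([2.0])

# Constant initialization: every hidden neuron starts with identical weights
const_params = make_params(jnp.full((8, 2), 0.3), jnp.full((1, 8), 0.3))
# Random initialization: each hidden neuron starts somewhere different
k1, k2 = jax.random.split(jax.random.PRNGKey(42))
rand_params = make_params(0.5 * jax.random.normal(k1, (8, 2)),
                          0.5 * jax.random.normal(k2, (1, 8)))

g_const = jax.grad(loss_fn)(const_params, x, y)["W1"]
g_rand = jax.grad(loss_fn)(rand_params, x, y)["W1"]

const_rows_identical = bool(jnp.allclose(g_const, g_const[0]))  # True: neurons stay clones
rand_rows_identical = bool(jnp.allclose(g_rand, g_rand[0]))     # False: neurons specialize
```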
Scale Matters Too
Even with random initialization, the scale of initial weights affects training:
Weights too large: Pre-activations \(z = \mathbf{W}\mathbf{x} + \mathbf{b}\) become large. For sigmoid/tanh, this saturates the activation (gradient \(\approx\) 0). For ReLU, values may explode through layers.
Weights too small: Signals shrink as they pass through layers. By the time they reach the output, everything is near zero. Gradients vanish.
Xavier/Glorot initialization: Initialize weights with variance \(\text{Var}(w) = \frac{2}{n_{\text{in}} + n_{\text{out}}}\) where \(n_{\text{in}}\), \(n_{\text{out}}\) are input/output dimensions of the layer. Keeps signal magnitudes stable across layers.
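Xavier/Glorot initialization sets the right scale by drawing
\[
w_{ij} \sim \mathcal{N}\!\left(0,\ \frac{2}{n_{\text{in}} + n_{\text{out}}}\right),
\]
or from a uniform distribution with the same variance. Equinox's default for eqx.nn.Linear is a closely related fan-in scaling: weights are drawn uniformly from \([-1/\sqrt{n_{\text{in}}},\ 1/\sqrt{n_{\text{in}}}]\).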
This keeps the variance of activations roughly constant across layers, enabling stable gradient flow.
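A quick numerical check of the scale argument, written as a sketch: it pushes unit-variance inputs through five purely linear 64-wide layers. The activations are left out on purpose to isolate the effect of weight scale. The comparison scale of 0.01 is an arbitrary "too small" choice.

```python
import jax
import jax.numpy as jnp

def glorot_normal(key, n_in, n_out):
    # Var(w) = 2 / (n_in + n_out), matching the Xavier/Glorot definition
    std = jnp.sqrt(2.0 / (n_in + n_out))
    return std * jax.random.normal(key, (n_out, n_in))

def tiny_init(key, n_in, n_out):
    # Deliberately too small: std 0.01 regardless of layer width
    return 0.01 * jax.random.normal(key, (n_out, n_in))

def propagate(weight_fn, key, num_layers=5, width=64, num_samples=1000):
    """Return the variance of the signal after num_layers linear layers."""
    key, data_key = jax.random.split(key)
    h = jax.random.normal(data_key, (num_samples, width))  # input variance ~ 1
    for _ in range(num_layers):
        key, w_key = jax.random.split(key)
        W = weight_fn(w_key, width, width)
        h = h @ W.T
    return float(jnp.var(h))

var_glorot = propagate(glorot_normal, jax.random.PRNGKey(0))  # stays of order 1
var_tiny = propagate(tiny_init, jax.random.PRNGKey(0))        # ~ (64 * 1e-4)**5
```

With the small scale, each layer multiplies the variance by \(64 \times 10^{-4} = 0.0064\), so after five layers the signal has essentially vanished.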
Why Different Seeds \(\to\) Different Solutions
Neural network loss landscapes are non-convex: they have many local minima and saddle points. Gradient descent finds a minimum, not necessarily the global minimum, and which minimum you find depends on where you start.
NoteNon-Convex Loss Landscape
Imagine two starting points on a hilly loss surface. Both can roll downhill, both can reach low loss, and they can still end in different valleys. Neural-network training is like this: different random initializations can produce different but similarly accurate functions.
Different random seeds produce different initial weights \(\to\) different optimization trajectories \(\to\) different final solutions.
These solutions may all have similar training loss but differ in their predictions, especially in regions with sparse data. This is the foundation for ensemble uncertainty (Section 1.10): train multiple networks with different seeds and measure how much they disagree.
1.9: Practical Training
Priority: Essential
Data Normalization
Standardization (z-score normalization): Transform each feature to zero mean and unit variance: \(\tilde{x} = (x - \mu)/\sigma\). Essential for neural network training.
Neural networks are sensitive to input/output scales. If \(Q_0 \in [0.5, 1.5]\) but \(a \in [50, 200]\), gradients with respect to \(a\)-related weights will dominate, causing slow and unstable learning.
Standardization transforms each feature to zero mean and unit variance:
\[\tilde{x} = \frac{x - \mu_x}{\sigma_x}\]
where \(\mu_x\) and \(\sigma_x\) are computed from your training set only.
WarningThe Cardinal Rule of Normalization
Compute \(\mu\) and \(\sigma\) from the training set only. Apply the same transformation (using training statistics) to:
Training inputs and outputs
Validation/test inputs and outputs
Any new data at inference time
Never use test set statistics for normalization — that would be data leakage, using information from data you’re supposed to predict.
Denormalization: To convert predictions back to physical units: \[x = \tilde{x} \cdot \sigma_x + \mu_x\]
Multi-Output Normalization
Multi-output regression: Predicting multiple target variables simultaneously. Each output should typically be normalized independently.
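Your emulator predicts three quantities: \((f_{\text{bound}}, \sigma_v, r_h)\). The MSE loss sums over all outputs:
\[
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{3} \left(\hat{y}_{ik} - y_{ik}\right)^2,
\]
where \(i\) indexes training examples and \(k\) indexes the three outputs.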
Problem: If outputs have different scales, this implicitly weights them unequally. Suppose \(r_h\) varies from 20–150 AU (range 130) while \(f_{\text{bound}}\) varies from 0.3–1.0 (range 0.7). Then an error that is the same fraction of each range gives a squared error about \((130/0.7)^2 \approx 35{,}000\times\) larger for \(r_h\). The network will focus on \(r_h\) at the expense of \(f_{\text{bound}}\).
Solution: Normalize each output independently to zero mean, unit variance. Then all three contribute roughly equally to the loss.
The normalization pattern (you’ll implement this):
NORMALIZATION PROCEDURE
=======================
1. COMPUTE STATISTICS (from training set only!)
mu = mean of training data, per feature # Shape: (3,) for outputs
sigma = std of training data, per feature # Shape: (3,) for outputs
2. NORMALIZE any data (train, test, or new):
x_tilde = (x - mu) / sigma
3. DENORMALIZE predictions back to physical units:
x = x_tilde * sigma + mu
Critical: Compute \(\mu\) and \(\sigma\) from training data, then use those same values everywhere — for normalizing test data, for denormalizing predictions, and inside your NumPyro model. Store these statistics alongside your trained model.
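Here is a minimal JAX sketch of the normalization procedure. It uses synthetic placeholder outputs whose scales loosely mimic \((f_{\rm bound}, \sigma_v, r_h)\). The helper names are illustrative, not part of any library. Statistics are computed per feature (`axis=0`) so that each output is normalized independently.

```python
import jax
import jax.numpy as jnp

def fit_standardizer(train):
    """Per-feature mean and std computed from the TRAINING set only."""
    mu = jnp.mean(train, axis=0)     # shape (num_features,)
    sigma = jnp.std(train, axis=0)   # shape (num_features,)
    return mu, sigma

def normalize(x, mu, sigma):
    return (x - mu) / sigma

def denormalize(x_tilde, mu, sigma):
    return x_tilde * sigma + mu

# Synthetic outputs with very different scales (placeholders, not real data)
key_train, key_test = jax.random.split(jax.random.PRNGKey(0))
scale = jnp.array([0.2, 0.5, 30.0])
offset = jnp.array([0.6, 2.0, 80.0])
y_train = offset + scale * jax.random.normal(key_train, (70, 3))
y_test = offset + scale * jax.random.normal(key_test, (15, 3))

mu_y, sigma_y = fit_standardizer(y_train)        # store these with the model
y_train_tilde = normalize(y_train, mu_y, sigma_y)
y_test_tilde = normalize(y_test, mu_y, sigma_y)  # training statistics, never test ones
y_test_back = denormalize(y_test_tilde, mu_y, sigma_y)
```

Only the training split is guaranteed to come out with exactly zero mean and unit variance. The test split will be close to, but not exactly, standardized. That is expected and honest.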
WarningNormalization and Likelihood Consistency
Training happens most cleanly in normalized units. Scientific interpretation usually happens in physical units. Do not mix them.
If the network predicts normalized outputs, then the inference pipeline must either:
denormalize predictions before comparing to physical observations, or
transform the observation vector and covariance into the same normalized units.
Either way, evaluate the likelihood with \(\mathbf{y}_{\rm pred}\), \(\mathbf{y}_{\rm obs}\), and uncertainties all expressed in the same units (physical units if you denormalize). Mixing normalized predictions with physical uncertainties is wrong, even if the code runs.
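The sketch below compares the two consistent options. All numbers are hypothetical placeholders for the training statistics, one emulator output, and one observation. Option 1 denormalizes the prediction. Option 2 normalizes both the observation and its uncertainty. The two log-likelihoods differ only by the constant \(\sum_k \log \sigma_{y,k}\), which does not depend on \((Q_0, a)\) and so leaves the posterior unchanged. The mixed version changes the answer.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical training statistics for the outputs (f_bound, sigma_v, r_h)
mu_y = jnp.array([0.6, 2.0, 80.0])
sigma_y = jnp.array([0.2, 0.5, 30.0])

# Hypothetical normalized emulator output and a physical-unit observation
pred_tilde = jnp.array([0.3, -0.4, 0.5])
y_obs = jnp.array([0.68, 1.75, 97.0])
sigma_obs = jnp.array([0.03, 0.1, 5.0])

# Option 1: denormalize the prediction and stay in physical units
pred_phys = pred_tilde * sigma_y + mu_y
loglike_phys = jnp.sum(norm.logpdf(y_obs, loc=pred_phys, scale=sigma_obs))

# Option 2: move the observation AND its uncertainty into normalized units
y_obs_tilde = (y_obs - mu_y) / sigma_y
sigma_obs_tilde = sigma_obs / sigma_y
loglike_norm = jnp.sum(norm.logpdf(y_obs_tilde, loc=pred_tilde, scale=sigma_obs_tilde))

# WRONG: normalized residuals compared with physical-unit uncertainties
loglike_mixed = jnp.sum(norm.logpdf(y_obs_tilde, loc=pred_tilde, scale=sigma_obs))
```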
TipPredict – Try – Explain: Normalization Leakage
Predict: What changes if you compute the mean and standard deviation using all data before the split?
Try: Write down which future information leaks into training when test examples help define the normalization.
Explain: Leakage can make an emulator look more accurate than it really is. Training statistics are part of the trained model; validation and test data must be transformed with those fixed values.
Hyperparameters
Hyperparameter: A setting that controls training but isn’t learned from data. Must be chosen before training begins.
Epoch: One complete pass through all training data. If you have 100 training examples and train for 500 epochs, the network sees each example 500 times.
Key hyperparameters for your emulator:
| Hyperparameter | Starting Value | Adjustment Guidance |
|---|---|---|
| Learning rate | \(10^{-3}\) | Decrease by 10\(\times\) if loss oscillates; increase if loss decreases too slowly |
| Epochs | 500–1000 | Until loss plateaus on validation set |
| Hidden layers | 2 | More layers rarely help for smooth, low-dimensional functions |
| Neurons per layer | 64 | Try 32 or 128 if 64 seems too small/large |
| Batch size | Full batch for modest datasets | Mini-batches matter only once the dataset is large enough that full-batch updates become slow |
Train/Validation/Test Split
With limited simulations, protect three roles:
Training set: Fit model parameters
Validation/calibration set: Choose hyperparameters, stop training, and estimate likelihood widths
Test set: Final held-out report after decisions are fixed
If your dataset is very small, a practical starting point is roughly 70/15/15. If that leaves too few test cases, use repeated random splits or cross-validation for development, but still keep a small final test set untouched for reporting.
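One way to produce a reproducible 70/15/15 split is sketched below. It shuffles indices once with a fixed JAX key and slices the result. The function name and the seed are illustrative choices.

```python
import jax
import jax.numpy as jnp

def train_val_test_split(key, num_examples, frac_train=0.70, frac_val=0.15):
    """Return disjoint index arrays for train / validation / test (the rest)."""
    perm = jax.random.permutation(key, num_examples)  # shuffle indices once
    n_train = int(round(frac_train * num_examples))
    n_val = int(round(frac_val * num_examples))
    train_idx = perm[:n_train]
    val_idx = perm[n_train:n_train + n_val]
    test_idx = perm[n_train + n_val:]  # untouched until the final report
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = train_val_test_split(jax.random.PRNGKey(2024), 100)
```

Saving the key (or the index arrays themselves) alongside the model makes it possible to prove later that the test cases never influenced training.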
WarningThe Test Set Is Sacred
Never use test data for any training decisions:
Don’t tune hyperparameters based on test performance
Don’t normalize using test statistics
Don’t stop training based on test loss
The test set exists solely to evaluate your final model. If you peek at it during development, you lose your ability to honestly assess generalization.
Convergence Criteria
Training is “done” when:
Loss has plateaued: Not decreasing meaningfully over ~100 epochs
Loss is reasonably small: MSE \(\lesssim 0.01\) on normalized data means RMSE \(\lesssim 0.1\), so typical errors are about 10% of a standard deviation or less
No pathologies: No NaN values and no sustained rise in training or validation loss
The Training Loop
Training loop: The iterative process of computing predictions, calculating loss, computing gradients, and updating parameters. Repeats for many epochs until convergence.
A typical training loop:
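```python
# Pseudocode: `model`, `params`, `x_train`, `y_train`, `num_epochs`, and
# `update_with_adam` are placeholders for your own objects
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def loss_fn(params):
    predictions = model(params, x_train)           # forward pass
    return jnp.mean((predictions - y_train) ** 2)  # MSE loss

losses = []
for epoch in range(num_epochs):
    # Forward + backward pass: loss value and gradients from the same call
    loss, gradients = jax.value_and_grad(loss_fn)(params)
    losses.append(float(loss))
    # Update parameters (a real Adam update also carries optimizer state)
    params = update_with_adam(params, gradients)
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: loss = {losses[-1]:.6f}")

# ALWAYS plot the loss curve
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.yscale("log")
```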
Always plot the loss curve. Its shape reveals problems that printed numbers hide.
NoteLoss Curve Diagnostics
| Curve shape | Likely diagnosis | First response |
|---|---|---|
| Smooth decrease, then plateau | Healthy training | Stop when validation loss also plateaus |
| Large oscillations or divergence | Learning rate too high | Decrease learning rate by 10\(\times\) |
| Very slow decrease | Learning rate too low | Increase learning rate or train longer |
| Training loss falls while validation loss rises | Overfitting | Early stop, simplify, or add data |
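Here is a minimal sketch of early stopping on validation loss. It uses full-batch gradient descent and keeps the parameters from the epoch with the lowest validation loss. The toy problem (degree-9 polynomial features for a noisy \(\sin(3x)\)), the learning rate, and the patience of 200 epochs are illustrative assumptions. An Adam-trained emulator would use the same bookkeeping.

```python
import jax
import jax.numpy as jnp

def fit_with_early_stopping(loss_fn, params, train_data, val_data,
                            lr=0.05, max_epochs=3000, patience=200):
    """Gradient descent that tracks validation loss and returns the best parameters."""
    value_and_grad = jax.jit(jax.value_and_grad(loss_fn))
    val_loss_fn = jax.jit(loss_fn)
    best_val, best_params, best_epoch = float("inf"), params, 0
    train_hist, val_hist = [], []
    for epoch in range(max_epochs):
        train_loss, grads = value_and_grad(params, *train_data)
        val_loss = float(val_loss_fn(params, *val_data))  # same params as train_loss
        train_hist.append(float(train_loss))
        val_hist.append(val_loss)
        if val_loss < best_val:
            best_val, best_params, best_epoch = val_loss, params, epoch
        elif epoch - best_epoch >= patience:
            break  # no validation improvement for `patience` epochs
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return best_params, best_epoch, train_hist, val_hist

# Toy problem (illustrative only)
def features(x):
    return jnp.stack([x ** d for d in range(10)], axis=-1)

def mse(params, X, y):
    return jnp.mean((X @ params["w"] - y) ** 2)

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x_tr = jax.random.uniform(k1, (15,), minval=-1.0, maxval=1.0)
x_va = jax.random.uniform(k2, (15,), minval=-1.0, maxval=1.0)
y_tr = jnp.sin(3 * x_tr) + 0.1 * jax.random.normal(k3, (15,))
y_va = jnp.sin(3 * x_va) + 0.1 * jax.random.normal(k4, (15,))

best_params, best_epoch, train_hist, val_hist = fit_with_early_stopping(
    mse, {"w": jnp.zeros(10)}, (features(x_tr), y_tr), (features(x_va), y_va))
```

Plot `train_hist` and `val_hist` on the same log-scaled axes. A widening gap between the two curves is the overfitting signature in the table.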
Debugging Checklist
ImportantWhen Training Goes Wrong
Before training — verify setup:
During training — monitor health:
After training — validate results:
Common fixes:
| Symptom | Likely Cause | Fix |
|---|---|---|
| Loss stays constant | Learning rate too small | Increase by 10\(\times\) |
| Loss oscillates wildly | Learning rate too large | Decrease by 10\(\times\) |
| Loss becomes NaN | Numerical instability | Lower LR, check for division by zero |
| Loss plateaus high | Architecture too simple, or bug | Add neurons/layers, check code |
| Train loss good, validation loss bad | Overfitting | Early stopping, more data, or accept it |
1.10: Uncertainty via Ensembles
Priority: Essential
The Problem with Single Networks
A trained neural network gives point predictions — single numbers with no uncertainty. But scientific applications need uncertainty quantification:
How confident is this prediction?
Are we extrapolating outside the training distribution?
How should prediction uncertainty propagate to inference?
A single network provides none of this. Worse, networks can be confidently wrong, especially when extrapolating beyond training data. An ensemble helps, but ensemble spread is not automatically the total uncertainty needed for inference.
Deep Ensembles
Ensemble: A collection of models that make independent predictions. Disagreement among members quantifies uncertainty. For neural networks, different random initializations suffice to create an ensemble.
The simplest approach to neural network uncertainty: train multiple networks with different random seeds.
Why does this work? From Section 1.8: different initializations lead to different local minima. These solutions may give similar predictions where training data is dense but diverge where data is sparse.
Procedure:
Train \(M\) networks with different random seeds (typically \(M = 3\)–5)
For any input \(\mathbf{x}\):
Get predictions from each: \(\hat{y}_1(\mathbf{x}), \ldots, \hat{y}_M(\mathbf{x})\)
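Mean prediction: \(\bar{y}(\mathbf{x}) = \frac{1}{M} \sum_{m=1}^{M} \hat{y}_m(\mathbf{x})\)
Spread (sample standard deviation): \(\sigma_y(\mathbf{x}) = \sqrt{\frac{1}{M-1} \sum_{m=1}^{M} \left(\hat{y}_m(\mathbf{x}) - \bar{y}(\mathbf{x})\right)^2}\)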
Epistemic uncertainty: Uncertainty due to limited knowledge — here, limited training data. Reducible with more data. Distinct from aleatoric uncertainty (irreducible noise in the process itself).
The spread \(\sigma_y\) estimates one part of epistemic uncertainty — uncertainty from limited training data and training trajectory. It is a warning light, not a complete error budget.
TipPredict – Try – Explain: Ensemble Uncertainty
Predict: Where should an ensemble disagree most: near dense training points, near domain edges, or far outside the sampled region?
Try: For one output such as \(f_{\rm bound}\), plot a one-dimensional slice in \(Q_0\) at fixed \(a\): training points, each ensemble member, the ensemble mean, and a shaded \(\pm 2\sigma\) band.
Explain: Ensemble spread is a warning light for epistemic uncertainty. It is most useful when you compare it to where the training data actually live.
Interpreting Ensemble Spread
Uncertainty should be high:
Near edges of training distribution (extrapolation begins)
In gaps between training points
Where the function is complex/rapidly varying
Uncertainty should be low:
In regions densely covered by training data
Where all ensemble members agree
Where the function is simple/slowly varying
NoteConnection to Module 5: Epistemic vs. Aleatoric
Remember from Module 5:
Epistemic uncertainty is reducible with more data. It represents “we’re uncertain because we haven’t seen enough examples.” Ensemble spread captures this — with more training data, ensemble members would agree better.
Aleatoric uncertainty is irreducible noise in the process itself. For your emulator, this can include scatter from different random cluster realizations at fixed macro-parameters \((Q_0, a)\). The exact same initial conditions and deterministic code path should reproduce the same result; the scatter comes from changing the realization, not from rerunning an identical state.
Ensembles capture epistemic uncertainty but not aleatoric uncertainty. The ensemble members all learned from the same training data, so they share the same aleatoric floor. More sophisticated approaches (predicting variance directly, Bayesian neural networks) can capture both, but ensembles are a practical first step.
WarningEnsemble Spread Is Not Total Uncertainty
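For inference, a more honest uncertainty budget (assuming the error sources are independent, so their covariances add) is
\[
\Sigma_{\rm tot} = \Sigma_{\rm obs} + \Sigma_{\rm emu} + \Sigma_{\rm sim}
\]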
Here \(\Sigma_{\rm obs}\) represents observational or target-measurement uncertainty, \(\Sigma_{\rm emu}\) represents emulator approximation uncertainty, and \(\Sigma_{\rm sim}\) represents simulator stochasticity or numerical uncertainty. Ensemble spread can help estimate part of \(\Sigma_{\rm emu}\), but it does not automatically include the other terms. It must be checked against held-out residuals.
Predict: If your emulator is accurate in the center of parameter space but poor near the edges, what happens to the posterior over \((Q_0, a)\) if you use an unrealistically tiny uncertainty?
Try: Sketch a likelihood surface with a narrow peak in the wrong location.
Explain: An overconfident emulator can produce an overconfident posterior. Calibration is what prevents a fast approximation from becoming a fast wrong answer.
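A simple calibration check compares held-out residuals with the predicted uncertainty. For Gaussian errors, about 95% of residuals should fall within \(\pm 2\sigma\), and the standardized residuals should have a standard deviation near 1. The sketch below uses synthetic residuals with a known scale, purely for illustration. In practice, `residuals` would be simulator outputs minus the ensemble mean on the validation or calibration set, and `sigma_pred` the predicted uncertainty for those cases.

```python
import jax
import jax.numpy as jnp

def coverage(residuals, sigma_pred, k=2.0):
    """Fraction of held-out residuals inside +/- k * predicted sigma."""
    z = residuals / sigma_pred
    return float(jnp.mean(jnp.abs(z) < k))

# Synthetic held-out residuals with a true error scale of 0.1 (illustrative only)
key = jax.random.PRNGKey(7)
true_sigma = 0.1
residuals = true_sigma * jax.random.normal(key, (5000,))

cov_honest = coverage(residuals, sigma_pred=0.1)           # ~0.95: calibrated
cov_overconfident = coverage(residuals, sigma_pred=0.025)  # far below 0.95
z_std = float(jnp.std(residuals / 0.1))                    # ~1 when calibrated
```

An uncertainty that is 4\(\times\) too small captures only about 38% of the residuals inside its nominal \(\pm 2\sigma\) band. That is the "narrow peak in the wrong location" failure in quantitative form.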
Ensemble in Practice
The ensemble workflow is conceptually simple:
ENSEMBLE WORKFLOW
=================
Training:
1. Split random key into M independent keys
2. Train M models, each with a different key
3. Store all trained models
Prediction:
1. Get predictions from all M models
2. Stack predictions into array of shape (M, num_outputs)
3. Mean = jnp.mean(predictions, axis=0)
4. Std = jnp.std(predictions, axis=0, ddof=1)  # ddof=1: divide by M-1 (sample std)
Implementation hints:
jax.random.split(key, M) creates M independent keys
Store models in a Python list
For single input: preds = jnp.array([model(x) for model in models])
For batched inputs: use jax.vmap inside the list comprehension
For your final project, \(M = 5\) ensemble members is a good balance between computational cost and uncertainty quality. You’ll implement this yourself — the pattern above tells you what to compute, not how to code it.
Conceptual Checkpoint: Training and Uncertainty
Before moving to implementation, verify your understanding:
ImportantSelf-Assessment
Loss interpretation: If your MSE loss on normalized data is 0.04, roughly how large are typical prediction errors relative to output standard deviations?
Learning rate: Describe what happens to the loss curve if the learning rate is (a) too large, (b) too small, (c) well-chosen.
Normalization: Why must you use training set statistics (not test set) when normalizing test data?
Multi-output: Your emulator predicts three outputs. Before normalization, \(r_h\) has 100\(\times\) larger variance than \(f_{\rm bound}\). What problem does this cause, and how do you fix it?
Ensembles: You train 5 networks on the same data but with different random seeds. Why do they converge to different solutions?
Uncertainty interpretation: Your ensemble predicts \(r_h = 85 \pm 3\) AU for one input and \(r_h = 120 \pm 20\) AU for another. What does the larger uncertainty in the second case tell you?
TipCheck Your Answers
MSE = 0.04 means RMSE = 0.2. On normalized data (std = 1), typical errors are ~20% of one standard deviation — quite good.
(a) Too large: loss oscillates wildly, may diverge. (b) Too small: loss decreases very slowly, may not reach good minimum in reasonable time. (c) Well-chosen: smooth decrease, then plateau.
Using test statistics would be data leakage — incorporating information from data you’re supposed to predict. The model must work with only training-time information.
Without normalization, the loss is dominated by \(r_h\) (larger squared errors). The network focuses on \(r_h\) at the expense of \(f_{\rm bound}\). Fix: normalize each output independently to unit variance.
The loss landscape has many local minima. Different initial weights lead to different optimization trajectories, ending at different minima with similar loss but different predictions.
The second prediction has higher epistemic uncertainty. The ensemble members disagree more, likely because that input region had sparser training data. You should be less confident in the second prediction.
1.11: The JAX Ecosystem — Equinox and Optax
Priority: Essential
Equinox: Neural Networks as PyTrees
Equinox: A JAX library for neural networks by Patrick Kidger. Models are PyTrees (nested containers JAX understands), enabling seamless use of jax.grad, jax.jit, jax.vmap.
Equinox makes neural networks fit naturally into JAX’s functional paradigm:
Models are PyTrees: Parameters live in a tree structure. JAX transformations (grad, jit, vmap) work seamlessly because they know how to traverse PyTrees.
Models are callables: An Equinox model is a function. You call it: output = model(input).
Explicit state: No hidden mutable variables. All parameters are explicit. This makes debugging easier and enables JAX’s transformations.
Learning from Documentation
ImportantYour Primary Resources
The final project requires you to learn Equinox and Optax from their documentation. This is intentional — reading documentation and adapting examples is a core professional skill.
Equinox (https://docs.kidger.site/equinox/):
“Getting Started”: How to define an eqx.Module with __init__ and __call__
“Train a Neural Network”: Complete training loop pattern you’ll adapt
Optax (https://optax.readthedocs.io/en/latest/):
“Quick Start”: How to create optimizers and apply updates
Your task: Work through these tutorials before starting your emulator. Run the examples. Modify them. Break them and fix them. Then adapt the patterns to your problem.
Glass-Box Exercise: Manual Forward Pass
Before using any library, make sure you understand what’s happening inside:
TipBuild It Yourself First
Implement a forward pass manually using only JAX primitives:
```python
import jax.numpy as jnp

def manual_forward(params, x):
    """Forward pass through a 2-64-64-3 MLP.

    params: dict with keys 'W1', 'b1', 'W2', 'b2', 'W3', 'b3'
    x: input array of shape (2,)
    Returns: output array of shape (3,)
    """
    # Layer 1: linear transformation + ReLU
    z1 = params['W1'] @ x + params['b1']
    h1 = jnp.maximum(0, z1)
    # Layer 2: linear transformation + ReLU
    z2 = params['W2'] @ h1 + params['b2']
    h2 = jnp.maximum(0, z2)
    # Output layer: linear only (no activation)
    y = params['W3'] @ h2 + params['b3']
    return y
```
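To run the exercise, you need a `params` dictionary with the right shapes. The sketch below builds one with Glorot-normal weights and zero biases. This is our choice for illustration; as noted above, Equinox's own default differs slightly. `init_params` is a hypothetical helper. The forward pass is repeated so the block runs on its own. Counting the leaves confirms the size of a 2-64-64-3 network: \(2\cdot64 + 64 + 64\cdot64 + 64 + 64\cdot3 + 3 = 4547\) parameters.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(2, 64, 64, 3)):
    """Glorot-normal weights and zero biases in the dict layout used by manual_forward."""
    params = {}
    keys = jax.random.split(key, len(sizes) - 1)
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:]), start=1):
        std = jnp.sqrt(2.0 / (n_in + n_out))
        params[f"W{i}"] = std * jax.random.normal(keys[i - 1], (n_out, n_in))
        params[f"b{i}"] = jnp.zeros(n_out)
    return params

def manual_forward(params, x):
    # Same computation as the exercise: two ReLU hidden layers, linear output
    h1 = jnp.maximum(0, params["W1"] @ x + params["b1"])
    h2 = jnp.maximum(0, params["W2"] @ h1 + params["b2"])
    return params["W3"] @ h2 + params["b3"]

params = init_params(jax.random.PRNGKey(0))
num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))

y_single = manual_forward(params, jnp.array([0.2, -0.5]))  # shape (3,)
x_batch = jnp.zeros((8, 2))
# vmap maps over the batch axis of x while sharing the same params
y_batch = jax.vmap(manual_forward, in_axes=(None, 0))(params, x_batch)  # (8, 3)
```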
This exercise shows that neural networks are just compositions of simple operations. Equinox handles the bookkeeping — but you should understand what it’s managing.
A Teaching Example: Learning \(\sin(x)\)
To illustrate Equinox patterns without solving your final project, let’s build a tiny network that learns the sine function. This is NOT your emulator architecture — adapt the concepts, don’t copy the code.
```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# A SIMPLE 1D example - your emulator will differ!
class TinySinNet(eqx.Module):
    layer1: eqx.nn.Linear
    layer2: eqx.nn.Linear

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.layer1 = eqx.nn.Linear(1, 16, key=k1)  # 1 input
        self.layer2 = eqx.nn.Linear(16, 1, key=k2)  # 1 output

    def __call__(self, x):
        # x has shape (1,); ReLU on the hidden layer, linear output layer
        x = jax.nn.relu(self.layer1(x))
        return self.layer2(x)
```
WarningHow Your Emulator Differs
TinySinNet has two layers (1 \(\to\) 16 \(\to\) 1), so only one hidden layer. Your emulator (2 \(\to\) 64 \(\to\) 64 \(\to\) 3, with inputs \((Q_0, a)\)) differs:
Two hidden layers of 64 neurons each, not one layer of 16
3 outputs, not 1
ReLU activation after each hidden layer
Adapt the pattern (how to define eqx.Module, split keys, write __call__), not the specific architecture.
Worked Mini-Example: A Tiny End-to-End Regressor
The example below is intentionally small. It shows the full workflow on a toy regression problem: normalize data, define an Equinox model, train it with Optax, and convert predictions back to physical units.
NoteCode Literacy: Tiny End-to-End Regressor
Purpose: Show the full Equinox/Optax training pattern on a toy problem before you adapt it to the emulator.
Inputs and outputs: The input is a one-dimensional array \(x\); the target is \(\sin(3x)\); the output is a trained model plus a physical-unit RMSE.
Common pitfall: Do not copy the architecture into the final project. Copy the workflow: train-only normalization, vmap for batches, filtered gradients, optimizer updates, denormalized evaluation.
The model is just a callable function with explicit parameters.
Normalization happens outside the model and must be saved for later use.
jax.vmap(model) handles the batch dimension.
eqx.filter_value_and_grad and optax.adam give you the core training loop.
After training, you must denormalize predictions before interpreting them.
For your final project, the structure is the same. The differences are the scientific dataset, the emulator architecture, the held-out evaluation, and the uncertainty/inference layers you build on top.
Key patterns to notice (these transfer to your emulator):
Inherit from eqx.Module — makes your class a PyTree
Declare layers as typed attributes — Equinox tracks these
Split keys for each layer — ensures independent initialization
__call__ defines forward pass — apply layers and activations
Essential Equinox Patterns
These small patterns are the building blocks. You’ll combine them for your emulator.
```python
# Activations are just JAX functions
x = jax.nn.relu(layer(x))  # ReLU
x = jax.nn.tanh(layer(x))  # Tanh (if you needed it)
```
Pattern 3: Filtering arrays from models
```python
# Equinox models contain arrays (parameters) and non-arrays (structure).
# Many operations need just the arrays:
params = eqx.filter(model, eqx.is_array)
```
Pattern 4: Computing gradients
```python
# eqx.filter_value_and_grad works like jax.value_and_grad
# but handles models with non-array components
loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
```
Pattern 5: Applying updates
```python
# After computing updates from the optimizer:
new_model = eqx.apply_updates(model, updates)
```
Pattern 6: JIT-compiling with models
```python
# Use eqx.filter_jit instead of jax.jit for functions involving eqx.Module
@eqx.filter_jit
def train_step(model, opt_state, x, y):  # argument list is illustrative
    ...
```
The Training Loop Structure
Here’s the structure of a training loop — you’ll fill in the details:
TRAINING LOOP PSEUDOCODE
========================
1. INITIALIZE
- Create model with random key
- Create optimizer (e.g., optax.adam)
- Initialize optimizer state from model parameters
2. FOR EACH EPOCH:
a. FORWARD PASS
- Compute predictions: pred = model(x_train)
- (Use jax.vmap to handle batches)
b. COMPUTE LOSS
- MSE: loss = mean((pred - y_train)**2)
c. BACKWARD PASS
- Compute gradients w.r.t. model parameters
- (Use eqx.filter_value_and_grad)
d. UPDATE
- Get updates from optimizer
- Apply updates to model
e. LOG
- Store loss for plotting
- Print progress periodically
3. PLOT LOSS CURVE (always!)
4. RETURN trained model
Your task: Translate this pseudocode into working Equinox/Optax code. The documentation tutorials show exactly how.
Training an Ensemble: The Idea
For uncertainty quantification, you’ll train multiple networks:
ENSEMBLE PSEUDOCODE
===================
1. Split your random key into M keys (one per model)
2. FOR EACH key:
- Train a model using that key
- (Different key -> different initialization -> different solution)
- Store the trained model
3. TO PREDICT with uncertainty:
- Get prediction from each model
- Mean = average of predictions
- Uncertainty = standard deviation of predictions
Implementation hints:
jax.random.split(key, n) gives you n independent keys
Store models in a Python list
For ensemble prediction, iterate over models and stack results with jnp.array([...])
Use jnp.mean(..., axis=0) and jnp.std(..., axis=0, ddof=1) to aggregate
WarningWhat You Must Figure Out
The public final-project docs and starter patterns give you the workflow, but you still need to make the key implementation decisions. You need to:
Define your model class — How many layers? What dimensions? Where do activations go?
Write the training step — How do you combine eqx.filter_value_and_grad, optimizer.update, and eqx.apply_updates?
Handle batches — When do you use jax.vmap?
Train the ensemble — How do you manage multiple models with different initializations?
The Equinox “Train a Neural Network” tutorial shows a complete example. Study it, understand it, then adapt it to your simulator-and-emulator workflow — don’t copy it blindly.
TipSaving Your Trained Models
You’ll need to save trained models for use in the inference phase. Equinox provides serialization via eqx.tree_serialise_leaves and eqx.tree_deserialise_leaves.
Important: Save your normalization statistics (mean, std for inputs and outputs) alongside your models — you’ll need them for inference. A simple approach is to save everything in a dictionary or use separate files.
See the Equinox documentation on serialization for details.
1.12: From Emulator to Inference
Priority: Essential
The Goal
You’ve trained a fast emulator. Now use it for Bayesian inference: given observed or held-out cluster summary statistics, infer the initial conditions that produced them.
The target is the posterior \(p(Q_0, a \,|\, \mathbf{y}_{\rm obs})\). It asks: after seeing the cluster summaries, which initial virial ratios and scale radii remain plausible?
The emulator replaces the expensive N-body simulation inside the likelihood. NumPyro’s NUTS sampler will call that likelihood thousands of times, so speed matters. Scientific trust, however, comes from the likelihood and uncertainty model, not from speed alone.
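A common starting point is a diagonal Gaussian likelihood,
\[
p(\mathbf{y}_{\rm obs} \,|\, Q_0, a) = \prod_{k=1}^{3} \mathcal{N}\!\left(y_{{\rm obs},k} \,\middle|\, y_{{\rm pred},k}(Q_0, a),\ \sigma_k^2\right),
\]
where \(k\) runs over \((f_{\rm bound}, \sigma_v, r_h)\) and \(\sigma_k^2\) is the total variance for output \(k\). This diagonal assumption is not automatic. It says errors in bound fraction, velocity dispersion, and half-mass radius are independent after conditioning on the model. If held-out residuals show correlated errors, this likelihood is too simple.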
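The sketch below shows what the diagonal assumption means in code. The per-output uncertainties are hypothetical placeholders. They are combined in quadrature, which assumes independent sources. A sum of independent one-dimensional Normal log-densities then gives exactly the same number as a multivariate Normal with a diagonal covariance matrix. If held-out residuals are correlated, the diagonal matrix would need to be replaced by a full covariance.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm, multivariate_normal

# Hypothetical per-output uncertainties for (f_bound, sigma_v, r_h), physical units
sigma_obs = jnp.array([0.03, 0.10, 5.0])
sigma_emu = jnp.array([0.02, 0.05, 3.0])   # e.g. from calibrated ensemble spread
sigma_sim = jnp.array([0.01, 0.05, 2.0])   # e.g. realization-to-realization scatter

# Independent error sources add in variance, not in standard deviation
sigma_tot = jnp.sqrt(sigma_obs ** 2 + sigma_emu ** 2 + sigma_sim ** 2)

def log_likelihood_diag(y_obs, y_pred, sigma):
    """Diagonal Gaussian: sum of independent 1D Normal log-densities."""
    return jnp.sum(norm.logpdf(y_obs, loc=y_pred, scale=sigma))

y_obs = jnp.array([0.68, 1.75, 97.0])
y_pred = jnp.array([0.66, 1.80, 95.0])

ll_diag = log_likelihood_diag(y_obs, y_pred, sigma_tot)
# Same value from a multivariate Normal with a diagonal covariance matrix
ll_mvn = multivariate_normal.logpdf(y_obs, mean=y_pred, cov=jnp.diag(sigma_tot ** 2))
```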
Bayes’ theorem then gives the object we actually want:
\[
p(Q_0, a \,|\, \mathbf{y}_{\rm obs})
\propto
p(\mathbf{y}_{\rm obs} \,|\, Q_0, a)
p(Q_0, a).
\]
The prior \(p(Q_0, a)\) encodes physically plausible initial conditions and should also keep the sampler inside the domain where the emulator has scientific authority. The likelihood encodes agreement with the observed cluster summaries. The posterior is the set of initial conditions still plausible after seeing the data.
NoteLikelihood vs. Posterior
The likelihood and posterior condition in opposite directions:
Likelihood: \(p(\mathbf{y}_{\rm obs} \,|\, Q_0, a)\) asks how plausible the data are if the initial conditions are true.
Posterior: \(p(Q_0, a \,|\, \mathbf{y}_{\rm obs})\) asks how plausible the initial conditions are after seeing the data.
NumPyro samples the posterior. To do that, you must provide priors and a likelihood.
The Conceptual Framework
Probabilistic programming: A paradigm where you write a generative model (how data is produced from parameters), and the framework handles inference automatically.
In NumPyro, you write a generative model — code describing how observations arise from parameters. The structure directly mirrors Bayes’ theorem:
GENERATIVE MODEL STRUCTURE
==========================
1. PRIOR: What do we believe about parameters before seeing data?
- Sample Q0 from some distribution (e.g., Uniform over training range)
- Sample a from some distribution
2. FORWARD MODEL: Given parameters, what do we predict?
- Normalize the sampled (Q0, a) using TRAINING statistics
- Pass through your emulator to get predicted (f_bound, sigma_v, r_h)
- Denormalize predictions back to physical units
3. LIKELIHOOD: How probable are actual observations given predictions?
- Compare predictions to observations
- Account for observational, simulation, and emulator uncertainty
The key insight: Your emulator serves as the forward model. Instead of running an expensive N-body simulation for each \((Q_0, a)\) the sampler proposes, you call your fast neural network.
What You Need to Figure Out
Important: NumPyro Learning Path
The final project requires you to build a NumPyro model. Here’s what to learn:
NumPyro documentation (https://num.pyro.ai/):
“Getting Started”: Basic model structure with numpyro.sample
“Bayesian Regression”: Example closest to your task
MCMC and NUTS: How to run inference
Key NumPyro primitives you’ll use:
numpyro.sample("name", distribution) — sample a parameter from a prior
numpyro.sample("obs", distribution, obs=data) — define likelihood given observed data
dist.Uniform(low, high) — uniform prior
dist.Normal(loc, scale) — Gaussian likelihood
NUTS(model) — the sampler (you know this from Module 5!)
MCMC(kernel, num_warmup=..., num_samples=...): run the chain (warmup and sample counts are keyword arguments)
The Normalization Challenge
Your emulator was trained on normalized data. Your NumPyro model must handle this correctly:
NORMALIZATION FLOW
==================
Physical parameters (Q0, a) from prior
|
Normalize using TRAINING statistics
|
Normalized inputs to emulator
|
Emulator predicts normalized outputs
|
Denormalize using TRAINING statistics
|
Physical predictions (f_bound, sigma_v, r_h)
|
Compare to physical observations in likelihood
Warning: Critical Detail
You must use the same normalization statistics (mean and standard deviation from the training set) that you used during training. Store them when you train your emulator, and pass them to your inference model.
A common bug is accidentally using different statistics, for example recomputing them from validation or test data. This silently produces wrong predictions.
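One way to make this warning hard to violate is to compute the statistics once, from the training split only, and carry them around as a single immutable object. In this minimal sketch, the `NormStats` container and the helper names are illustrative. Their attribute names match those read by `predict_physical` in this section.

```python
from typing import NamedTuple

import jax.numpy as jnp


class NormStats(NamedTuple):
    """Training-set normalization statistics (hypothetical container)."""
    x_mean: jnp.ndarray
    x_std: jnp.ndarray
    y_mean: jnp.ndarray
    y_std: jnp.ndarray


def fit_norm_stats(x_train, y_train):
    """Per-column mean and std computed from the TRAINING split only."""
    return NormStats(
        x_mean=x_train.mean(axis=0), x_std=x_train.std(axis=0),
        y_mean=y_train.mean(axis=0), y_std=y_train.std(axis=0),
    )


def normalize_x(stats, x):
    """Physical inputs -> normalized inputs, always with training statistics."""
    return (x - stats.x_mean) / stats.x_std


def denormalize_y(stats, y_norm):
    """Normalized outputs -> physical outputs, always with training statistics."""
    return y_norm * stats.y_std + stats.y_mean


# Persist next to the model weights (e.g. numpy.savez with stats._asdict())
# and reload the same object for inference.
```

Validation and test inputs normalized this way will not have exactly zero mean and unit variance. That is expected and is not a reason to recompute the statistics.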
```python
import jax.numpy as jnp


def predict_physical(model, norm_stats, x_physical):
    """Predict physical summary statistics from physical inputs.

    Purpose: apply the trained emulator inside an inference model while
    keeping normalization consistent with training.
    Inputs: model, saved normalization statistics, and x_physical with
    components [Q0, a].
    Output: y_physical = [f_bound, sigma_v, r_h] in physical units.
    Common pitfall: using validation/test statistics here instead of saved
    training statistics changes the forward model silently.
    """
    # Inputs: physical -> normalized, using TRAINING-set statistics
    x_norm = (x_physical - norm_stats.x_mean) / norm_stats.x_std
    # The emulator only ever sees normalized data
    y_norm = model(x_norm)
    # Outputs: normalized -> physical, again with TRAINING-set statistics
    y_physical = y_norm * norm_stats.y_std + norm_stats.y_mean
    return y_physical


def diagonal_gaussian_loglike(y_obs, y_pred, sigma_vec):
    """Compute a diagonal Gaussian log-likelihood in physical units.

    Purpose: score agreement between observed and emulator-predicted summaries.
    Inputs: y_obs, y_pred, and sigma_vec in the same physical output units.
    Output: scalar log-likelihood.
    Common pitfall: sigma_vec has one entry per output, not one universal
    number for dimensionless, velocity, and length summaries.
    """
    residual = y_obs - y_pred
    # Sum of independent 1-D Gaussian log-densities, one per output
    return -0.5 * jnp.sum(
        (residual / sigma_vec) ** 2 + jnp.log(2 * jnp.pi * sigma_vec ** 2)
    )
```
Choosing the Likelihood Width
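The likelihood requires an uncertainty model. A useful starting point is a diagonal vector of widths,
\[
\boldsymbol{\sigma}_{\rm total} = \left(\sigma_{f_{\rm bound}},\ \sigma_{\sigma_v},\ \sigma_{r_h}\right),
\]
with one component for each output. These entries must be in the same physical units as \(\mathbf{y}_{\rm obs}\): dimensionless for \(f_{\rm bound}\), velocity units for \(\sigma_v\), and length units for \(r_h\).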
For synthetic observations from your simulations, this uncertainty may include:
Emulator error: the network does not perfectly predict simulations.
Simulation stochasticity: different random seeds or initial realizations produce scatter even for fixed \((Q_0, a)\).
Numerical uncertainty: timestep, integration, or finite-precision effects perturb the summary statistic.
Observational uncertainty: only relevant if you compare to real observed cluster summaries.
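Suppose you treat these sources as independent Gaussian errors, an assumption to check rather than a given. Their variances then add, so each per-output width is a quadrature sum. The worked sketch below uses made-up component values as placeholders.

```python
import jax.numpy as jnp

# Hypothetical per-output error budgets for [f_bound, sigma_v, r_h],
# each in that output's physical units. Values are placeholders.
sigma_emulator = jnp.array([0.02, 0.05, 0.03])
sigma_stochastic = jnp.array([0.03, 0.04, 0.02])
sigma_numerical = jnp.array([0.005, 0.01, 0.005])


def combine_in_quadrature(*sigmas):
    """Independent Gaussian errors: variances add, so widths add in quadrature."""
    return jnp.sqrt(sum(s ** 2 for s in sigmas))


sigma_total = combine_in_quadrature(sigma_emulator, sigma_stochastic, sigma_numerical)
```

Watch for double counting. A validation RMSE measured against single-seed simulations already contains some simulation scatter, so adding a separate stochastic term on top can overstate the width.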
Practical baseline: Use validation/calibration-set RMSE for each output as a first estimate of \(\boldsymbol{\sigma}_{\rm total}\). Keep the test set untouched for final reporting and parameter-recovery checks after the likelihood width is fixed.
COMPUTING sigma_total
=====================
1. Get emulator predictions on validation/calibration set (normalized)
2. Denormalize to physical units
3. Compare to true validation/calibration values
4. RMSE = sqrt(mean((pred - true)**2)) for each output
5. Use these three RMSE values as an initial sigma_total
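Steps 1–5 fit in a few lines. This minimal sketch rests on three placeholder assumptions:

- `predict_norm` is your trained emulator, already vectorized over a batch (for example via `jax.vmap`).
- `x_val_norm` holds validation inputs normalized with training statistics.
- `y_val_phys` holds the true validation outputs in physical units.

```python
import jax.numpy as jnp


def per_output_rmse(predict_norm, x_val_norm, y_val_phys, y_mean, y_std):
    """Per-output validation RMSE in physical units (initial sigma_total)."""
    y_pred_norm = predict_norm(x_val_norm)              # step 1: normalized predictions
    y_pred_phys = y_pred_norm * y_std + y_mean          # step 2: denormalize (training stats)
    residuals = y_pred_phys - y_val_phys                # step 3: compare to truth
    return jnp.sqrt(jnp.mean(residuals ** 2, axis=0))   # step 4: one RMSE per output
```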
Note: Uncertainty Calibration
If \(\boldsymbol{\sigma}_{\rm total}\) is too small, posteriors will be overconfident and may miss the true values. If it is too large, posteriors will be vague and uninformative.
Validation/calibration RMSE is a principled starting point. To check it, run parameter recovery on validation cases: do about 95% of the 95% credible intervals contain the true parameter values?
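A cheaper first check, before running any posteriors, is whether the chosen widths describe the residuals themselves. If the errors are roughly Gaussian with those widths, about 95% of standardized residuals should satisfy \(|r_k / \sigma_k| < 1.96\). This sketch assumes residuals in physical units with shape (n_cases, 3). If the widths were set as the RMSE of these same residuals, the check only tests the Gaussian shape of the tails.

```python
import jax.numpy as jnp


def gaussian_interval_coverage(residuals, sigma_vec, z=1.96):
    """Fraction of residuals inside +/- z * sigma, separately for each output.

    For well-calibrated Gaussian errors and z = 1.96 this should be ~0.95.
    Much lower suggests sigma is too small; much higher, too large.
    """
    standardized = jnp.abs(residuals / sigma_vec)
    return jnp.mean(standardized < z, axis=0)
```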
Using an Ensemble in the Likelihood
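If you train an ensemble, you can estimate an input-dependent emulator covariance. For \(M\) ensemble members with physical-unit predictions \(\hat{\mathbf{y}}_m(Q_0, a)\), the ensemble mean serves as the forward model and the sample covariance of the members as the emulator term:
\[
\bar{\mathbf{y}}(Q_0, a) = \frac{1}{M}\sum_{m=1}^{M} \hat{\mathbf{y}}_m(Q_0, a),
\qquad
\boldsymbol{\Sigma}_{\rm emu}(Q_0, a) = \frac{1}{M-1}\sum_{m=1}^{M} \left(\hat{\mathbf{y}}_m - \bar{\mathbf{y}}\right)\left(\hat{\mathbf{y}}_m - \bar{\mathbf{y}}\right)^{\top}.
\]
The remaining error terms (simulation stochasticity, numerical and observational uncertainty) are then added, for example as a diagonal matrix, to form the total likelihood covariance.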
This is still an approximation. It treats ensemble spread as a proxy for emulator uncertainty, so it must be checked with held-out simulations.
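Here is a sketch of that ensemble-based likelihood. It assumes `ensemble_preds` stacks the \(M\) members' physical-unit predictions at one \((Q_0, a)\) into shape (M, 3). It also assumes the other error terms enter as a diagonal; `sigma_other` is a hypothetical name. With only a handful of members the sample covariance is noisy, which is one more reason to validate it on held-out simulations.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def ensemble_gaussian_loglike(y_obs, ensemble_preds, sigma_other):
    """Gaussian log-likelihood with ensemble mean and ensemble-spread covariance.

    ensemble_preds: (M, 3) member predictions in physical units at one (Q0, a).
    sigma_other:    (3,) per-output widths for the non-emulator error terms.
    """
    y_mean = ensemble_preds.mean(axis=0)
    # Sample covariance across members (rows are members, 1/(M-1) normalization)
    cov_emu = jnp.cov(ensemble_preds, rowvar=False)
    cov_total = cov_emu + jnp.diag(sigma_other ** 2)
    return multivariate_normal.logpdf(y_obs, y_mean, cov_total)
```

When all members agree exactly, `cov_emu` vanishes and this reduces to the diagonal Gaussian likelihood with widths `sigma_other`.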
Tip: Predict – Try – Explain: Degeneracy in \((Q_0, a)\)
Predict: Could a larger initial scale radius with lower virial ratio produce a final half-mass radius similar to a more compact but dynamically hotter cluster?
Try: Sketch likelihood contours in \((Q_0, a)\) space for one summary statistic, then sketch how adding \(f_{\rm bound}\) and \(\sigma_v\) might narrow or rotate the contour.
Explain: Multiple initial conditions can produce similar summaries. The posterior may be an elongated ridge, not a round blob. This is why multiple summary statistics and honest uncertainty matter.
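To see such a ridge numerically, you can evaluate a Gaussian log-likelihood on a grid using a toy forward model. The two functions below are hypothetical stand-ins chosen only to produce a degeneracy. They are not N-body results, and your emulator will give different shapes; the widths are also placeholders. The sketch counts grid cells within 2 of the peak log-likelihood, first with one summary and then with two.

```python
import jax.numpy as jnp


# Hypothetical toy forward model, NOT N-body physics: r_h depends only on
# a * (1 + Q0), so a whole curve of (Q0, a) pairs gives the same r_h.
def toy_r_h(q0, a):
    return a * (1.0 + q0)


def toy_sigma_v(q0, a):
    return jnp.sqrt(q0 / a)


def gaussian_loglike(pred, obs, sigma):
    """1-D Gaussian log-likelihood, up to an additive constant."""
    return -0.5 * ((pred - obs) / sigma) ** 2


def n_cells_near_peak(ll, delta=2.0):
    """Number of grid cells whose log-likelihood is within delta of the maximum."""
    return int(jnp.sum(ll > ll.max() - delta))


q0_grid, a_grid = jnp.meshgrid(jnp.linspace(0.3, 1.5, 200), jnp.linspace(0.5, 2.0, 200))
q0_true, a_true = 0.8, 1.2

# One summary: a long ridge along a * (1 + Q0) = const
ll_rh = gaussian_loglike(toy_r_h(q0_grid, a_grid), toy_r_h(q0_true, a_true), 0.05)
# Two summaries: the second constraint cuts the ridge down
ll_both = ll_rh + gaussian_loglike(
    toy_sigma_v(q0_grid, a_grid), toy_sigma_v(q0_true, a_true), 0.02
)

print("cells near peak, r_h only:     ", n_cells_near_peak(ll_rh))
print("cells near peak, r_h + sigma_v:", n_cells_near_peak(ll_both))
```

Passing these arrays to a contour plotter, for example matplotlib's `plt.contour(q0_grid, a_grid, ll_rh)`, shows the ridge and how the second summary narrows it.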
Validation: Parameter Recovery
The critical test of your inference pipeline:
PARAMETER RECOVERY TEST
=======================
1. Pick a simulation from your TEST set
- You know the true (Q0, a) that generated it
- You have its summary statistics (f_bound, sigma_v, r_h)
2. Treat summary statistics as "observations"
3. Run inference to get posterior over (Q0, a)
4. Check: Do 95% credible intervals contain true values?
- Compute 2.5th and 97.5th percentiles of posterior samples
- True value should fall within this range
5. Repeat for several test cases
- If ~95% contain truth, your pipeline is well-calibrated
- If fewer, sigma_total may be too small (overconfident)
- If more, sigma_total may be too large (uninformative)
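Steps 4 and 5 of this test reduce to a percentile calculation and a count. This minimal sketch assumes two arrays. The first holds posterior draws for each test case, with shape (n_cases, n_draws, 2) for \((Q_0, a)\). The second holds the matching true values, with shape (n_cases, 2).

```python
import jax.numpy as jnp


def credible_interval_coverage(posterior_draws, true_params, level=0.95):
    """Empirical coverage of central credible intervals, per parameter.

    posterior_draws: (n_cases, n_draws, n_params) posterior samples.
    true_params:     (n_cases, n_params) parameters that generated each case.
    Returns the fraction of cases whose true value lies inside the interval.
    """
    tail = 100.0 * (1.0 - level) / 2.0                      # 2.5 for a 95% interval
    lo = jnp.percentile(posterior_draws, tail, axis=1)          # (n_cases, n_params)
    hi = jnp.percentile(posterior_draws, 100.0 - tail, axis=1)
    inside = (true_params >= lo) & (true_params <= hi)
    return jnp.mean(inside, axis=0)
```

With only a few test cases, this fraction is noisy. At a true coverage of 95%, ten cases yield 8 or 9 hits fairly often. Look for a systematic shortfall rather than reacting to a single miss.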
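Visualization: Create corner plots showing the joint posterior over \((Q_0, a)\) with the true values marked. In a well-calibrated pipeline, the true values fall inside the 95% contours in roughly 95% of recovery tests. An occasional miss is therefore expected; repeated misses are a warning sign.
Note: Parameter-Recovery Plot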
For each held-out recovery test, make a corner plot for \((Q_0, a)\) with the true values marked. The 68% and 95% contours should be interpreted only after the likelihood width has been calibrated on validation data.
Connecting the Pieces
Here’s how all the components fit together in your final project:
COMPLETE INFERENCE PIPELINE
===========================
┌─────────────────────────────────────────────────────────┐
│ TRAINING PHASE (done before inference) │
├─────────────────────────────────────────────────────────┤
│ 1. Generate N-body training data │
│ 2. Compute and store normalization statistics │
│ 3. Train ensemble of emulators │
│ 4. Estimate sigma_total from validation residuals │
│ and reserve test set for final reporting │
└─────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────┐
│ INFERENCE PHASE │
├─────────────────────────────────────────────────────────┤
│ 1. Define NumPyro model: │
│ - Priors on (Q0, a) │
│ - Emulator as forward model (with normalization!) │
│ - Gaussian likelihood with sigma_total │
│ │
│ 2. Run NUTS: │
│ - Warmup: adapts step size and mass matrix │
│ - Sampling: collects posterior samples │
│ │
│ 3. Analyze posterior: │
│ - Summary statistics (mean, std, credible intervals)│
│ - Corner plots │
│ - Parameter recovery validation │
└─────────────────────────────────────────────────────────┘
Tip: Why This Is Fast
Remember the calculation from Section 1.1? Direct inference with N-body simulations can become infeasible because NUTS may need many forward-model calls.
With your emulator:
NUTS can call the forward model many thousands of times during warmup and sampling
Each JIT-compiled emulator call is typically much faster than rerunning a simulation
The actual speedup depends on model size, hardware, vectorization, and sampler settings
Once the forward model is cheap, sampler overhead can account for much of the run time. Report measured wall-clock time from your own machine instead of quoting a universal speedup factor.
This is why JIT compilation matters: the @eqx.filter_jit decorator lets JAX compile repeated model calls. It often gives large speedups, but the honest claim is the one you benchmark for your pipeline.
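A fair timing must exclude the one-time compilation cost and wait for JAX's asynchronous dispatch to finish. This minimal sketch applies `jax.jit` to a hypothetical stand-in function; substitute your `eqx.filter_jit`-wrapped emulator. The numbers it prints are specific to your hardware, and those are the numbers to report.

```python
import time

import jax
import jax.numpy as jnp


def toy_forward(x):
    """Hypothetical stand-in for an emulator: a small fixed 2 -> 64 -> 3 map."""
    w1 = jnp.ones((2, 64)) * 0.1
    w2 = jnp.ones((64, 3)) * 0.1
    return jnp.tanh(x @ w1) @ w2


def time_per_call(fn, x, n_calls=1000):
    """Average wall-clock seconds per call, excluding the warm-up call."""
    fn(x).block_until_ready()            # first call compiles and warms caches
    start = time.perf_counter()
    for _ in range(n_calls):
        out = fn(x)
    out.block_until_ready()              # wait for asynchronous dispatch to finish
    return (time.perf_counter() - start) / n_calls


fast = jax.jit(toy_forward)
x = jnp.array([1.0, 1.2])
print(f"jit:   {time_per_call(fast, x):.2e} s/call")
print(f"eager: {time_per_call(toy_forward, x, n_calls=100):.2e} s/call")
```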
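Important: Minimum Viable Final-Project Workflow
Before you add optional sophistication, make this baseline work:
Split simulations into training, validation, and test sets, and keep the test set untouched until final reporting.
Normalize inputs and outputs with training-set statistics only, and save those statistics.
Train the 2-input MLP emulator and compare it against a simple baseline.
Check held-out residual diagnostics across \(Q_0\), \(a\), and predicted values.
Estimate ensemble uncertainty and check its calibration.
Set \(\boldsymbol{\sigma}_{\rm total}\) from validation residuals and use the emulator as the forward model inside a NumPyro likelihood.
Run parameter-recovery tests on held-out test simulations.
If these pieces are solid, the project is scientifically defensible. Optional uncertainty methods should strengthen this workflow, not distract from it.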
1.13: Synthesis
Priority: Essential
The Complete Workflow
You’ve learned a powerful workflow for modern computational science:
flowchart TD
subgraph Data["1. Generate Training Data"]
sims["Run JAX-native Leapfrog simulations\n(Latin Hypercube in Q0, a)"]
stats["Extract summary statistics\n(f_bound, sigma_v, r_h)"]
sims --> stats
end
subgraph Prep["2. Data Preparation"]
split["Train/validation/test split"]
norm["Normalize to zero mean,\nunit variance"]
split --> norm
end
subgraph Arch["3. Design Architecture"]
mlp["MLP: 2 -> 64 -> 64 -> 3"]
act["ReLU hidden, linear output"]
end
subgraph Train["4. Train Network"]
init["Xavier initialization"]
opt["Adam optimizer"]
loss["MSE loss"]
init --> opt --> loss
end
subgraph Ens["5. Build Ensemble"]
multi["Train 5 models\n(different seeds)"]
uncert["Estimate emulator spread\nthen check calibration"]
multi --> uncert
end
subgraph Eval["6. Evaluate"]
test["Held-out accuracy"]
plots["Predicted vs. true\nand residual diagnostics"]
edges["Domain and edge behavior"]
end
subgraph Infer["7. Inference"]
numpyro["NumPyro model:\nprior -> emulator -> physical likelihood"]
nuts["NUTS sampling"]
post["Posterior over (Q0, a)"]
numpyro --> nuts --> post
end
Data --> Prep --> Arch --> Train --> Ens --> Eval --> Infer
Connections Across the Course
| Concept | Where You Learned It | How It Appears Here |
|---|---|---|
| Models as compression | Module 1 | Emulator compresses input \(\to\) output relationship into weights |
| | | MSE can be derived from a Gaussian residual model under specific assumptions |
| Gradient-based inference | Module 5 (HMC) | Training uses gradients; NUTS uses gradients |
| Automatic differentiation | The JAX module | Backprop = reverse-mode autodiff (JAX handles it) |
| JIT compilation | The JAX module | Essential for fast emulator evaluation in inference |
| Functional programming | The JAX module | Equinox models are pure functions with explicit state |
| N-body dynamics | Module 3, Project 2, and the final-project JAX rebuild | The simulator map your emulator approximates |
What Makes a Good Emulator?
Accuracy: Predictions match simulations in the training region
Test RMSE should be small relative to output variation
Predicted vs. true plots cluster along \(y = x\)
Calibration: Uncertainty estimates are meaningful
About 95% of true values fall within calibrated 95% prediction intervals
Uncertainty increases where training data is sparse
Posterior recovery tests have credible intervals with sensible empirical coverage
Speed: Fast enough for inference
Each prediction is much faster than running a simulation (measure and report your own per-call timings)
JIT compilation is essential
Robustness: Sensible behavior at boundaries
No claims outside the training range unless explicitly tested
Ensemble uncertainty increases appropriately
Residual diagnostics do not reveal systematic bias across \(Q_0\), \(a\), or predicted values
Looking Ahead: The Final Project
You’re now ready for the final project. You’ll:
Generate training data: Latin Hypercube sampling in \((Q_0, a)\) space; compute summary statistics from N-body simulations using your JAX-native final-project simulator
Build and train an emulator: MLP in Equinox; train ensemble with Optax
Evaluate accuracy: Test metrics, predicted vs. true plots, uncertainty behavior
Perform inference: NumPyro model with emulator as forward model; NUTS sampling; corner plots showing parameter recovery
This workflow (expensive simulations \(\to\) fast surrogates \(\to\) Bayesian inference) is a common pattern in current simulation-based research, and you are building every piece of it yourself.
Final Conceptual Checkpoint
Before starting the final project, ensure you can answer:
Important: Comprehensive Self-Assessment
Architecture & Theory:
What is the role of activation functions? What happens without them?
What does the Universal Approximation Theorem guarantee? What doesn’t it guarantee?
How many parameters does a network with layers 5 \(\to\) 100 \(\to\) 100 \(\to\) 3 have?
Training:
Explain how backpropagation relates to jax.grad and the autodiff you learned in the JAX module.
Why must normalization use only training set statistics?
Your training loss decreases but validation loss increases after epoch 200. What’s happening?
What hyperparameters would you adjust if loss oscillates wildly?
Uncertainty:
Why do different random seeds produce different trained networks?
What type of uncertainty do ensembles capture? What don’t they capture?
Where would you expect ensemble uncertainty to be highest?
Integration:
In your NumPyro model, what role does the emulator play?
How do you choose \(\boldsymbol{\sigma}_{\rm total}\) for the likelihood?
How would you check if your inference is working correctly?
Practical:
What does eqx.filter_jit do differently from jax.jit?
What is the purpose of jax.vmap(model) in the training step?
Why is the output layer linear (no activation) for your emulator?