Technical Guide: From Simulation to Surrogate
COMP 536 | Final Project Support
This guide is the recommended technical lane for the final project. It is more specific than the public assignment page, but it is still guidance rather than a second hidden rubric. When in doubt, the official contract lives on the Final Project Assignment, the syllabus, and the course-wide project rubric.
This guide is meant to work together with the JAX and ML readings in The Learnable Universe:
- JAX Part 2: Core Transformations
- JAX Part 3: Physics Applications
- Machine Learning Overview
- Neural Networks and Scientific Emulation
If you are stuck on implementation details, read those pages before you start inventing a more complicated solution.
The Big Picture
The core problem is computational cost. A direct N-body simulation is rich enough to capture interesting dynamics, but too expensive to evaluate thousands of times inside an inference loop. The final project asks you to take the validated physics and workflow from your earlier N-body work, rebuild that simulator in JAX, verify that it behaves sensibly, and then turn it into a reusable scientific instrument:
- rebuild the Project 2 simulator ideas into a trustworthy JAX-native Leapfrog simulator,
- generate trustworthy simulation outputs,
- train a fast emulator on those outputs,
- evaluate where the emulator is accurate and where it is fragile,
- use the emulator for an inverse problem.
This workflow shows up across modern computational science. Cosmologists emulate power spectra, climate scientists emulate circulation models, and engineers emulate expensive CFD pipelines. In this course, the N-body cluster problem gives you a concrete version of the same idea.
Recommended Scientific Framing
A strong default framing is:
- rebuild your Project 2 N-body simulator in a JAX-native Leapfrog implementation,
- vary the initial virial ratio \(Q_0\) and Plummer scale radius \(a\),
- choose a defensible space-filling design, preferably Latin Hypercube Sampling, over that \((Q_0, a)\) input space,
- evolve a cluster with your JAX N-body code,
- extract summary statistics from the final state,
- train a neural network to map \((Q_0, a)\) to those diagnostics,
- use the emulator to infer which initial conditions produced a held-out outcome.
That framing is attractive because it keeps the input space small enough to visualize while still forcing you to build a full scientific workflow.
Part 0: Build The JAX-Native Leapfrog Simulator
Before you start coding, it is worth revisiting the JAX readings listed at the top of this guide.
Why This Part Must Be Explicit
Pedagogically, this project is much stronger if the simulator is not treated as a black box. The course has spent weeks building numerical-methods judgment, and the final project asks you to use that judgment in a modern JAX setting. If the simulator arrives pre-built, the project collapses into “train a neural network on someone else’s data pipeline,” which is a much thinner synthesis.
For that reason, the expected code path includes a JAX-native N-body simulator using Leapfrog integration, rebuilt from Project 2’s validated physics and diagnostics. Reuse and adapt your own earlier scientific work rather than starting from zero, but make sure the final-project pipeline exposes a genuinely JAX-native implementation, not a thin wrapper around the old one.
What “Student-Built” Means Here
At minimum, you should own:
- the state representation for positions, velocities, and masses,
- the force or acceleration computation,
- the Leapfrog update step,
- the simulation driver used to generate training data,
- the validation checks that justify trusting the simulator.
This does not mean every project has to become a giant performance-optimization exercise or a blind rewrite from scratch. It does mean the scientific core of the simulator should still be yours, should visibly descend from Project 2, and should live in the repo in readable JAX-native form.
For this course, “JAX-native” should mean more than “I replaced NumPy with jax.numpy.” A convincing JAX-native rebuild usually shows:
- an array-first state representation that is natural for JAX,
- a clean update or step function that passes state explicitly rather than hiding mutation inside object state,
- at least one meaningful use of jit, vmap, or another JAX workflow tool where it actually helps the project,
- a short explanation in the README or report of why this design is better for repeated simulation, data generation, or emulator training.
The goal is not to use every JAX feature. The goal is to show that you understand why JAX changes how you structure scientific code.
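To make the "array-first state plus pure step function" idea concrete, here is a minimal sketch, not a reference implementation. The `State` container, the function names, the unit choice `G = 1`, and the Plummer-style softening length `eps` are all assumptions made for illustration; your own rebuild should carry over the conventions you validated in Project 2.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical array-first simulation state (a JAX pytree)."""
    pos: jax.Array   # (N, 3) positions
    vel: jax.Array   # (N, 3) velocities
    mass: jax.Array  # (N,) masses
    t: jax.Array     # scalar time


def accelerations(pos, mass, G=1.0, eps=1e-2):
    """Softened pairwise gravity via broadcasting (O(N^2) memory)."""
    dr = pos[None, :, :] - pos[:, None, :]      # dr[i, j] = r_j - r_i
    r2 = jnp.sum(dr**2, axis=-1) + eps**2       # softened squared distance
    # The self term vanishes because dr[i, i] = 0 and eps > 0.
    return G * jnp.sum(mass[None, :, None] * dr * r2[:, :, None] ** -1.5, axis=1)


@jax.jit
def leapfrog_step(state, dt):
    """One kick-drift-kick step; returns a new State instead of mutating."""
    v_half = state.vel + 0.5 * dt * accelerations(state.pos, state.mass)
    pos = state.pos + dt * v_half
    vel = v_half + 0.5 * dt * accelerations(pos, state.mass)
    return State(pos, vel, state.mass, state.t + dt)


# Two equal masses on a circular orbit, used only as a smoke test.
state = State(
    pos=jnp.array([[-0.5, 0.0, 0.0], [0.5, 0.0, 0.0]]),
    vel=jnp.array([[0.0, -0.5, 0.0], [0.0, 0.5, 0.0]]),
    mass=jnp.array([0.5, 0.5]),
    t=jnp.array(0.0),
)
new_state = leapfrog_step(state, 0.01)
```

Because the step takes a state and returns a state, the same function can later be wrapped in jit, mapped over many initial conditions, or scanned over time without restructuring the code.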
What To Validate Before Data Generation
Before generating an emulator dataset, you should show evidence such as:
- a small-\(N\) sanity check where trajectories behave qualitatively as expected,
- bounded energy behavior rather than obvious secular blow-up,
- consistent center-of-mass or momentum behavior when appropriate,
- stable outputs under a timestep that you can defend.
The exact validation suite can vary, but it should be strong enough that later emulator failures are not secretly simulator bugs.
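One way to turn the energy and momentum checks above into something you can run is sketched here for a two-body circular orbit. This is an illustrative check under assumed conventions (`G = 1`, a softening length `EPS`, a fixed timestep of 0.01, and 1000 steps), not a required validation suite; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

G, EPS = 1.0, 1e-3  # assumed units and softening length


def accel(pos, mass):
    dr = pos[None, :, :] - pos[:, None, :]
    r2 = jnp.sum(dr**2, axis=-1) + EPS**2
    return G * jnp.sum(mass[None, :, None] * dr * r2[:, :, None] ** -1.5, axis=1)


@jax.jit
def total_energy(pos, vel, mass):
    kinetic = 0.5 * jnp.sum(mass * jnp.sum(vel**2, axis=-1))
    dr = pos[None, :, :] - pos[:, None, :]
    r = jnp.sqrt(jnp.sum(dr**2, axis=-1) + EPS**2)
    pair = mass[:, None] * mass[None, :] / r
    # Keep only i < j so that each pair is counted once.
    potential = -G * jnp.sum(jnp.triu(pair, k=1))
    return kinetic + potential


@jax.jit
def step(pos, vel, mass, dt):
    v_half = vel + 0.5 * dt * accel(pos, mass)
    pos = pos + dt * v_half
    vel = v_half + 0.5 * dt * accel(pos, mass)
    return pos, vel


# Equal masses, separation 1, circular speed 0.5 each (period 2*pi).
pos = jnp.array([[-0.5, 0.0, 0.0], [0.5, 0.0, 0.0]])
vel = jnp.array([[0.0, -0.5, 0.0], [0.0, 0.5, 0.0]])
mass = jnp.array([0.5, 0.5])

e0 = total_energy(pos, vel, mass)
max_drift = 0.0
for _ in range(1000):
    pos, vel = step(pos, vel, mass, 0.01)
    drift = abs(float((total_energy(pos, vel, mass) - e0) / e0))
    max_drift = max(max_drift, drift)

momentum = jnp.sum(mass[:, None] * vel, axis=0)
print("max relative energy drift:", max_drift)
print("total momentum:", momentum)
```

For a symplectic scheme the relative energy error should oscillate within a bound rather than grow steadily. Repeating the run at half the timestep and comparing `max_drift` is a simple way to support the timestep you eventually defend.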
Part 1: Generate Training Data
Choose a Parameter Design That You Can Defend
You need a dataset that covers the input space well enough for the emulator to learn something real. A good starting point is a 2D design over \(Q_0\) and \(a\) with Latin Hypercube Sampling or a similarly defensible space-filling design, rather than naive uniform random draws. The main reason is coverage: with a modest sample budget, random draws can leave obvious holes, while a Latin hypercube gives you a more even spread across each axis.
A practical default is:
- training set: about 80 to 100 simulations,
- validation/calibration set: about 10 to 20 simulations for training choices, likelihood-width estimates, and uncertainty checks,
- held-out test set: about 20 simulations for final reporting and parameter-recovery examples after choices are fixed,
- an initial debug pass with a much smaller sample count before scaling up.
That small debug pass matters. It lets you verify the end-to-end workflow before you spend time generating the full dataset.
If your simulation budget is smaller, keep the roles separate even if the counts change: train on one subset, calibrate and tune on another, and reserve a final held-out set for the result you report.
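The design and split logic above can be sketched in a few lines of NumPy. This is a hand-rolled Latin hypercube for illustration; the function name, the seed, the range of \(a\), and the 90/20/20 counts are placeholders, and a library sampler would be an equally defensible choice.

```python
import numpy as np


def latin_hypercube(n_samples, bounds, rng):
    """Draw n_samples points with exactly one point per stratum on each axis."""
    bounds = np.asarray(bounds, dtype=float)   # shape (d, 2): [low, high]
    d = bounds.shape[0]
    # One independent permutation of the strata for every dimension.
    strata = np.column_stack([rng.permutation(n_samples) for _ in range(d)])
    unit = (strata + rng.uniform(size=(n_samples, d))) / n_samples  # in [0, 1)
    return bounds[:, 0] + unit * (bounds[:, 1] - bounds[:, 0])


rng = np.random.default_rng(536)
# Q0 range as used in this guide; the range of a is a placeholder.
bounds = [(0.5, 1.5), (0.5, 2.0)]
design = latin_hypercube(130, bounds, rng)

# Assign roles once, up front, so test runs never leak into tuning.
labels = np.array(["train"] * 90 + ["validation"] * 20 + ["test"] * 20)
split = rng.permutation(labels)
```

Randomly splitting one design keeps all three subsets inside the same parameter ranges. Drawing a separate small design for the test set is another reasonable option if you want the test points to be space-filling on their own.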
Minimum Dataset Schema
Save your emulator dataset in a format that another person can reload without reading your whole codebase first. A CSV, Parquet file, or NumPy archive is fine, but the meaning of each field must be clear.
A strong default table has one row per simulation:
| Field | Meaning |
|---|---|
| run_id | Unique identifier for the simulation |
| Q0 | Initial virial ratio used for this run |
| a | Plummer scale radius or chosen length-scale parameter |
| seed | Random seed or realization identifier |
| split | train, validation, calibration, or test |
| f_bound | Bound mass fraction summary statistic |
| sigma_v | Velocity dispersion summary statistic, with units stated in metadata or README |
| r_h | Half-mass radius summary statistic, with units stated in metadata or README |
| status | Optional flag such as ok, failed, or excluded, with a reason recorded elsewhere |
For example, a tiny debug dataset might begin like this:
| run_id | Q0 | a | seed | split | f_bound | sigma_v | r_h | status |
|---|---|---|---|---|---|---|---|---|
| run_0001 | 0.70 | 0.80 | 4301 | train | 0.84 | 1.52 | 2.90 | ok |
| run_0002 | 1.05 | 1.40 | 4302 | validation | 0.61 | 1.87 | 3.45 | ok |
| run_0003 | 1.35 | 1.10 | 4303 | test | 0.48 | 2.10 | 4.20 | ok |
Also save enough metadata to make the table scientifically meaningful: particle count \(N\), timestep, total integration time, softening prescription, IMF or mass model, unit convention, and the code version or commit used to generate the data.
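A minimal sketch of writing and reloading such a table with its metadata, using only the standard library, might look like the following. The file names `dataset.csv` and `metadata.json`, the helper names, and every metadata value shown are illustrative placeholders rather than recommended settings.

```python
import csv
import json
import tempfile
from pathlib import Path

FIELDS = ["run_id", "Q0", "a", "seed", "split", "f_bound", "sigma_v", "r_h", "status"]
FLOAT_FIELDS = ("Q0", "a", "f_bound", "sigma_v", "r_h")


def save_dataset(out_dir, rows, metadata):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "dataset.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        writer.writeheader()
        writer.writerows(rows)
    (out_dir / "metadata.json").write_text(json.dumps(metadata, indent=2))


def load_dataset(out_dir):
    out_dir = Path(out_dir)
    with open(out_dir / "dataset.csv", newline="") as f:
        rows = list(csv.DictReader(f))
    for row in rows:
        # CSV stores everything as text, so restore numeric types explicitly.
        for key in FLOAT_FIELDS:
            row[key] = float(row[key])
        row["seed"] = int(row["seed"])
    metadata = json.loads((out_dir / "metadata.json").read_text())
    return rows, metadata


rows = [
    {"run_id": "run_0001", "Q0": 0.70, "a": 0.80, "seed": 4301, "split": "train",
     "f_bound": 0.84, "sigma_v": 1.52, "r_h": 2.90, "status": "ok"},
]
metadata = {  # placeholder values for illustration only
    "n_particles": 200,
    "timestep": 0.01,
    "total_time": 10.0,
    "softening": "Plummer, eps = 0.01",
    "mass_model": "equal mass",
    "units": "N-body units, G = 1",
    "code_version": "<git commit hash>",
}

with tempfile.TemporaryDirectory() as tmp:
    save_dataset(tmp, rows, metadata)
    loaded_rows, loaded_meta = load_dataset(tmp)
```

Reloading the file in a fresh session and checking that types and values survive the round trip is a cheap test of the "another person can reload it" requirement.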
Hold the Right Things Fixed
For a course-scale project, it is reasonable to keep the following fixed while you vary \(Q_0\) and \(a\):
- particle count \(N\),
- IMF choice,
- integrator, which should be Leapfrog for the baseline course lane,
- timestep policy,
- softening prescription.
The important thing is not which values you choose first, but whether you can justify them and keep them consistent across the training run. If you change multiple physical ingredients at once, it becomes much harder to interpret failure modes later.
For the COMP 536 final-project baseline, a fixed timestep is enough if you can defend it with a convergence or energy-behavior check. You do not need to implement adaptive timestepping. Adaptive timestep control is optional stretch work, not part of the baseline course lane.
Summary Statistics Should Be Physically Legible
A useful starting set of outputs is:
- bound mass fraction \(f_{\rm bound}\),
- velocity dispersion \(\sigma_v\),
- half-mass radius \(r_h\).
Those are good teaching choices because each one maps back to a physically meaningful question:
- How much of the system remains bound?
- How energetic is the bound population?
- How spatially concentrated is the cluster?
If you compute different diagnostics, explain why they are the right observables for your version of the problem.
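As a hedged illustration, the three default statistics could be computed from a final state as sketched below. The specific definitions are assumptions, not the course's required ones: a particle counts as bound if its specific energy in the center-of-mass frame is negative (single pass, no iteration), \(\sigma_v\) is the 3D mass-weighted dispersion of bound particles, and \(r_h\) is measured about the bound particles' own center of mass. The function name and the `eps` softening are hypothetical.

```python
import numpy as np


def summary_stats(pos, vel, mass, G=1.0, eps=1e-2):
    """Return (f_bound, sigma_v, r_h) under the assumptions stated above."""
    m_tot = mass.sum()
    # Work in the center-of-mass frame.
    pos = pos - (mass[:, None] * pos).sum(axis=0) / m_tot
    vel = vel - (mass[:, None] * vel).sum(axis=0) / m_tot

    # Specific energy: kinetic term plus softened potential from all others.
    dr = pos[:, None, :] - pos[None, :, :]
    inv_r = 1.0 / np.sqrt((dr**2).sum(axis=-1) + eps**2)
    np.fill_diagonal(inv_r, 0.0)                      # no self-potential
    phi = -G * (inv_r * mass[None, :]).sum(axis=1)
    bound = 0.5 * (vel**2).sum(axis=-1) + phi < 0.0
    if not bound.any():
        return 0.0, float("nan"), float("nan")

    mb, vb, pb = mass[bound], vel[bound], pos[bound]
    f_bound = mb.sum() / m_tot

    # 3D mass-weighted velocity dispersion of the bound particles.
    v_mean = (mb[:, None] * vb).sum(axis=0) / mb.sum()
    sigma_v = np.sqrt((mb * ((vb - v_mean) ** 2).sum(axis=-1)).sum() / mb.sum())

    # Radius enclosing half of the bound mass.
    center = (mb[:, None] * pb).sum(axis=0) / mb.sum()
    radius = np.linalg.norm(pb - center, axis=1)
    order = np.argsort(radius)
    idx = np.searchsorted(np.cumsum(mb[order]), 0.5 * mb.sum())
    r_h = radius[order][idx]
    return float(f_bound), float(sigma_v), float(r_h)


# Synthetic check: 50 cold particles plus two symmetric fast escapers.
rng = np.random.default_rng(0)
pos = np.vstack([rng.normal(size=(50, 3)), [[10.0, 0.0, 0.0], [-10.0, 0.0, 0.0]]])
vel = np.vstack([np.zeros((50, 3)), [[100.0, 0.0, 0.0], [-100.0, 0.0, 0.0]]])
mass = np.full(52, 1.0 / 52)
f_bound, sigma_v, r_h = summary_stats(pos, vel, mass)
```

Testing the function on a constructed state where you already know the answer, as in the last lines, is a quick guard against sign or frame errors before you run it on real simulation output.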
Validation Before Machine Learning
Before you train an emulator, make sure the data pipeline itself is believable. At minimum, confirm that:
- the rebuilt JAX-native Leapfrog simulation runs reproducibly on a small sample,
- the extracted summary statistics stay in physically plausible ranges,
- the expected qualitative trends appear in simple diagnostic plots,
- your saved data can be reloaded without ambiguity.
If the training data is wrong, the emulator will only help you produce wrong answers faster. This is why the simulator validation step is part of the pedagogy, not just part of the plumbing.
What Counts As A Good Project 2 To JAX Transition
A good transition does not mean copying the old simulator line by line into JAX syntax. It means carrying forward the right things:
- the same physical equations,
- the same validation cases,
- the same diagnostic expectations,
- the same reasoning about why Leapfrog, as a symplectic integrator, suits long integrations.
What should change is the implementation style:
- cleaner state representation,
- JAX-native array operations,
- JAX-friendly step functions,
- a code path that can later support emulator training and fast repeated evaluation.
Project 2 To JAX Migration Checklist
Use your old Project 2 code as scientific evidence and design memory, not as a black box. A useful migration often looks like this:
| Project 2 idea | JAX-native final-project version |
|---|---|
| Body objects or mutable lists | Array-first state such as positions, velocities, masses, and time |
| Method that mutates simulation state | Pure step function: new_state = leapfrog_step(state, params) |
| Python loop over bodies for forces | Vectorized pairwise acceleration with jax.numpy broadcasting or vmap |
| Python loop over timesteps | A clear loop first; optionally jax.lax.scan after correctness is established |
| Diagnostic notebook cells | Reusable functions in summary_stats.py or diagnostics.py |
| Manual one-off plots | Reproducible scripts that regenerate validation and report figures |
Do the simplest readable JAX version first. Add jit, vmap, or lax.scan only after the uncompiled version is correct enough to test. Faster wrong code is still wrong.
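The "clear loop first, then lax.scan" path can be illustrated on a deliberately tiny system. The sketch below uses a unit harmonic oscillator rather than the N-body force so that the loop structure is the only thing on display; the function names are hypothetical, and the same pattern should apply to any pure step function.

```python
from functools import partial

import jax
import jax.numpy as jnp


def leapfrog_step(state, dt):
    """Kick-drift-kick step for a unit harmonic oscillator (a = -x)."""
    x, v = state
    v_half = v - 0.5 * dt * x
    x_new = x + dt * v_half
    v_new = v_half - 0.5 * dt * x_new
    return (x_new, v_new)


def run_loop(state, dt, n_steps):
    """Readable reference version: a plain Python loop."""
    for _ in range(n_steps):
        state = leapfrog_step(state, dt)
    return state


@partial(jax.jit, static_argnames="n_steps")
def run_scan(state, dt, n_steps):
    """Same integration expressed with jax.lax.scan and compiled once."""
    def body(carry, _):
        new = leapfrog_step(carry, dt)
        return new, new[0]            # also record x at every step

    return jax.lax.scan(body, state, None, length=n_steps)


state0 = (jnp.array(1.0, dtype=jnp.float32), jnp.array(0.0, dtype=jnp.float32))
final_loop = run_loop(state0, 0.05, 200)
final_scan, xs = run_scan(state0, 0.05, n_steps=200)
```

Keeping `run_loop` around after you adopt `run_scan` gives you a regression test: the two should agree to floating-point tolerance, and any disagreement points to a bug in the faster path.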
Part 2: Build the Emulator
Before you start training, review:
- Machine Learning Overview
- Neural Networks and Scientific Emulation
- Equinox documentation for model definitions, PyTrees, filtered transforms, and serialization
- Optax documentation for optimizer setup and update loops
This is the first project where you are expected to use mature scientific Python/JAX libraries instead of building every algorithm from scratch. That is intentional. Equinox, Optax, and NumPyro are professional tools; using them well is part of the learning goal. The glass-box standard still applies: you should understand what role each library plays, what data flows through it, and how you verified the result.
Why a Neural Network Is Reasonable Here
The recommended emulator is a small multilayer perceptron built in JAX with Equinox and optimized with Optax. For a low-dimensional regression problem like this, an MLP is a pragmatic choice:
- flexible enough to learn nonlinear structure,
- cheap to evaluate after training,
- easy to integrate with the JAX ecosystem you have already used in the course.
Normalize Inputs and Outputs
Normalization is not optional polish here. If one feature has a much larger numerical scale than another, gradient-based training becomes harder to interpret and tune. Compute normalization statistics on the training set, reuse them everywhere else, and make the inverse transformation part of your documented prediction path.
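A small sketch of train-only standardization with an explicit inverse is given here. The `Standardizer` container and the toy input ranges are assumptions for illustration; the point is only that the statistics come from the training split and are then reused unchanged.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class Standardizer:
    """Hypothetical holder for normalization statistics."""
    mean: np.ndarray
    std: np.ndarray

    def transform(self, x):
        return (x - self.mean) / self.std

    def inverse(self, z):
        return z * self.std + self.mean


def fit_standardizer(train_array):
    # Statistics come from the training split only; a zero std would need
    # special handling, which this sketch does not attempt.
    return Standardizer(train_array.mean(axis=0), train_array.std(axis=0))


rng = np.random.default_rng(0)
x_train = rng.uniform([0.5, 0.5], [1.5, 2.0], size=(90, 2))  # toy (Q0, a) inputs
x_test = rng.uniform([0.5, 0.5], [1.5, 2.0], size=(20, 2))

x_scaler = fit_standardizer(x_train)
z_train = x_scaler.transform(x_train)
z_test = x_scaler.transform(x_test)      # reuse the training statistics
x_back = x_scaler.inverse(z_test)        # documented inverse path
```

The same pattern applies to the outputs: fit a second standardizer on the training targets, train in normalized space, and apply `inverse` to emulator predictions before computing metrics in physical units.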
Start Simple
A good baseline architecture is small:
- input dimension 2,
- two hidden layers,
- a moderate hidden width,
- output dimension matching your summary statistics.
If that baseline already works, you have learned something useful. There is no prize for making the model larger before the data pipeline and evaluation logic are solid.
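To make the baseline architecture visible, here is a library-free sketch in plain JAX of a 2-input, two-hidden-layer, 3-output MLP trained by gradient descent on toy data. It is a glass-box stand-in, not the recommended lane: in the project itself an Equinox model and an Optax optimizer would replace `init_mlp` and `sgd_step`. The width of 32, the tanh activation, the learning rate, and the toy target function are all assumptions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-2  # assumed value for this toy problem


def init_mlp(key, sizes):
    """Return a list of (weights, bias) pairs, one per layer."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def mlp(params, x):
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b               # linear output layer for regression


def mse(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    loss, grads = jax.value_and_grad(mse)(params, x, y)
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return params, loss


k_init, k_data = jax.random.split(jax.random.PRNGKey(0))
params = init_mlp(k_init, [2, 32, 32, 3])

# Toy, already-normalized inputs and an arbitrary smooth target.
x = jax.random.uniform(k_data, (256, 2), minval=-1.0, maxval=1.0)
y = jnp.stack([jnp.sin(x[:, 0]), x[:, 0] * x[:, 1], x[:, 1] ** 2], axis=1)

losses = []
for _ in range(500):
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))
preds = mlp(params, x)
```

If you can explain every line of this sketch, you are in a good position to say what the Equinox and Optax versions are doing on your behalf.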
Beat One Simple Baseline First
Before you trust a neural emulator, compare it against at least one simple baseline. Good options include:
- predicting the training-set mean,
- a linear regression model,
- a nearest-neighbor or interpolation-style baseline.
This comparison matters because it answers a real scientific question: is the neural network learning meaningful structure, or is it only adding complexity? A baseline that is easy to beat is still useful, because it gives your held-out metrics context.
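Two of the baselines listed above fit in a few lines of NumPy. This sketch assumes inputs of shape `(n, 2)` and targets of shape `(n, n_outputs)`; the function names and the synthetic data are hypothetical.

```python
import numpy as np


def fit_mean_baseline(y_train):
    """Always predict the training-set mean of each output."""
    mean = y_train.mean(axis=0)
    return lambda x: np.tile(mean, (len(x), 1))


def fit_linear_baseline(x_train, y_train):
    """Ordinary least squares with an intercept, one column per output."""
    design = np.column_stack([np.ones(len(x_train)), x_train])
    coef, *_ = np.linalg.lstsq(design, y_train, rcond=None)
    return lambda x: np.column_stack([np.ones(len(x)), x]) @ coef


# Synthetic data that is exactly linear, so the linear baseline should be exact.
rng = np.random.default_rng(1)
x = rng.uniform(size=(30, 2))
y = x @ np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + np.array([0.1, 0.2, 0.3])

mean_model = fit_mean_baseline(y)
linear_model = fit_linear_baseline(x, y)
```

Both baselines return arrays with the same shape as the emulator's predictions, so they can be passed through the same evaluation code without special cases.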
Part 3: Evaluate the Emulator
Held-Out Accuracy Is the First Gate
Compute straightforward regression metrics on the held-out set, such as MAE and RMSE for each output. Then inspect predicted-versus-true plots. The numerical metrics tell you the scale of the error; the plots tell you whether the model is biased or systematically failing in a particular region.
For this project, that held-out evaluation should include your chosen simple baseline as well as the neural emulator. The exact baseline is up to you, but the comparison should make it easy to see what the neural model is buying you.
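A per-output metric helper keeps the emulator and the baseline on the same footing. The sketch below assumes predictions have already been mapped back to physical units; the function name and the default output names are illustrative.

```python
import numpy as np


def regression_metrics(y_true, y_pred, names=("f_bound", "sigma_v", "r_h")):
    """Return MAE and RMSE for each output column."""
    err = np.asarray(y_pred) - np.asarray(y_true)
    mae = np.mean(np.abs(err), axis=0)
    rmse = np.sqrt(np.mean(err**2, axis=0))
    return {
        name: {"mae": float(m), "rmse": float(r)}
        for name, m, r in zip(names, mae, rmse)
    }


# Tiny hand-checkable example with two outputs.
y_true = np.zeros((2, 2))
y_pred = np.array([[1.0, 3.0], [-1.0, -3.0]])
metrics = regression_metrics(y_true, y_pred, names=("out_a", "out_b"))
```

Calling the same function on the baseline's predictions and on the emulator's predictions gives a table in which the benefit of the neural model, if any, is read off directly.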
Check Behavior Near Edges
Interpolation is much easier than extrapolation. If your design space spans \(Q_0 \in [0.5, 1.5]\) and a chosen range of \(a\), you should explicitly ask whether the model is less reliable near the edges than near the center. Even a one-dimensional slice through parameter space can reveal whether the emulator remains smooth and sane where training coverage thins out.
Add Some Uncertainty Story
A simple and defensible starting point is an ensemble of independently initialized networks. The spread across the ensemble is not a complete uncertainty model, but it is often enough to expose obviously unstable regions and to keep you from treating a single deterministic predictor as omniscient.
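The ensemble idea combines naturally with the one-dimensional slice suggested for edge checks. In the sketch below the "members" are hypothetical closures standing in for independently trained networks, and the fixed value of \(a\) is a placeholder; only the stacking and spread logic is the point.

```python
import numpy as np


def ensemble_predict(members, x):
    """Return the ensemble mean and member-to-member spread at each point."""
    preds = np.stack([m(x) for m in members], axis=0)  # (members, points, outputs)
    return preds.mean(axis=0), preds.std(axis=0, ddof=1)


def make_member(w):
    # Stand-in for a trained network: maps (n, 2) inputs to (n, 1) outputs.
    return lambda x: np.column_stack([np.tanh(w * x[:, 0]) + x[:, 1]])


members = [make_member(w) for w in (0.9, 1.0, 1.1)]

# Slice through parameter space: sweep Q0 across its range at fixed a.
q0 = np.linspace(0.5, 1.5, 11)
slice_points = np.column_stack([q0, np.full_like(q0, 1.0)])
mean, spread = ensemble_predict(members, slice_points)
```

Plotting `mean` with a band of plus or minus `spread` along the slice, with the training points marked on the same axis, shows at a glance whether disagreement grows where coverage thins out.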
Part 4: Use the Emulator for Inference
The Inverse Problem
Once the emulator is fast enough, you can use it as the forward model inside a probabilistic program. The recommended lane is NumPyro with NUTS:
- sample priors on the physical parameters,
- call the emulator to generate predicted observables,
- compare predictions to a held-out synthetic observation through a likelihood model,
- inspect the posterior and ask whether the true parameters are recovered.
Before writing the inference model, work through the NumPyro Getting Started guide. You should be comfortable with numpyro.sample, distributions, NUTS, and MCMC before adapting the emulator as a forward model.
Do Not Skip the Likelihood Story
Students often jump straight from emulator predictions to a sampler without thinking carefully about the uncertainty model. Your likelihood width should reflect something you can defend, such as emulator error on the validation/calibration set or a similarly justified uncertainty budget. Keep the final held-out test cases for reporting and recovery checks after that width is fixed. If the likelihood is unrealistically tight, your posterior becomes falsely confident. If it is unrealistically broad, the posterior becomes uninformative.
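One defensible way to set the likelihood width is to use the emulator's RMSE on the validation/calibration set as a per-output standard deviation. The sketch below assumes independent Gaussian errors on each summary statistic, which is itself a modeling choice you should state; the function names and the optional `floor` argument are hypothetical.

```python
import numpy as np


def likelihood_width(y_val_true, y_val_pred, floor=0.0):
    """Per-output sigma from validation residuals, with an optional lower bound."""
    resid = np.asarray(y_val_pred) - np.asarray(y_val_true)
    return np.maximum(np.sqrt(np.mean(resid**2, axis=0)), floor)


def gaussian_loglike(obs, pred, sigma):
    """Log-likelihood of obs under independent Gaussians centered on pred."""
    z = (obs - pred) / sigma
    return float(np.sum(-0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)))


# Hand-checkable residuals: RMSE should be 1 and 2 for the two outputs.
y_val_true = np.zeros((4, 2))
y_val_pred = np.array([[1.0, 2.0], [-1.0, -2.0], [1.0, 2.0], [-1.0, -2.0]])
sigma = likelihood_width(y_val_true, y_val_pred)
```

Halving or doubling `sigma` and rerunning the inference is a quick sensitivity test: if the posterior changes character completely, the width is carrying more weight than the data and deserves a paragraph in the report.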
Recovery on Known Cases
The cleanest validation is to choose a held-out simulation with known true parameters, treat its summary statistics as synthetic observations, and ask whether your posterior places meaningful mass near the truth. A corner plot with the true values marked is usually the clearest presentation.
Package and Repo Structure
The exact layout is up to you, but a strong final-project repo usually separates concerns clearly:
- source package or modules,
- scripts or commands for data generation, training, and inference,
- tests,
- outputs and figures,
- a top-level README.md with reproduction instructions.
What matters most is that another person can identify the simulator, emulator, and inference entrypoints quickly and reproduce the core results without guessing.
Figures That Usually Belong in the Report
The following figure set is a strong default:
- training-data coverage in parameter space,
- training or validation loss curves,
- predicted-versus-true plots for the held-out set,
- a slice plot showing mean prediction and uncertainty behavior,
- a posterior or recovery plot for the inference stage.
Use captions to explain what scientific claim each figure supports. A figure should not be present only because it is standard.
Common Failure Modes
- Simulator failing? Stop emulator work and validate a tiny small-\(N\) case first.
- Dataset looks weird? Check units, seeds, split labels, and whether summary statistics are finite and physically plausible.
- Emulator will not train? Check train-only normalization, array shapes, and whether a simple baseline works.
- Ensemble uncertainty looks useless? Plot where the training points are; uncertainty cannot rescue poor coverage.
- Inference misses the truth? Check likelihood width and emulator error before changing sampler settings.
The Simulation, the Data Pipeline, and the Inference All Break at Once
This usually means you moved to inference before validating the upstream layers. Work in order:
- Leapfrog simulator sanity,
- summary-statistic sanity,
- emulator accuracy,
- inference.
The Neural Network Will Not Train
Check the basics first:
- normalization,
- NaN in the data,
- inconsistent train/test transforms,
- learning rate that is too large,
- output targets with very different scales.
The Posterior Misses the Truth
That can mean several different things:
- the emulator is inaccurate in that region,
- the likelihood width is too small,
- the summary statistics are not sufficiently informative,
- the simulator and the inference target are not actually matched.
The fix is almost never “run the sampler longer” until you understand which of those is happening.
Recommended Resources
- Final Project Assignment
- Expectations & Sample Repo Contract
- Final Project Launch Worksheet
- Growth Synthesis Guide
- Official Project Rubric & Grading Scheme
- JAX documentation
- JAX jit documentation
- JAX vmap documentation
- AI Use & Growth Mindset Policy
- Equinox documentation
- Optax documentation
- NumPyro Getting Started
- NumPyro documentation