Technical Guide: From Simulation to Surrogate

COMP 536 | Final Project Support

Author: Dr. Anna Rosen

Published: April 22, 2026

Note: How to use this page

This guide is the recommended technical lane for the final project. It is more specific than the public assignment page, but it is still guidance rather than a second hidden rubric. When in doubt, the official contract lives on the Final Project Assignment, the syllabus, and the course-wide project rubric.

Note: Paired course readings

This guide is meant to work together with the JAX and ML readings in The Learnable Universe.

If you are stuck on implementation details, read those pages before you start inventing a more complicated solution.

The Big Picture

The core problem is computational cost. A direct N-body simulation is rich enough to capture interesting dynamics, but too expensive to evaluate thousands of times inside an inference loop. The final project asks you to take the validated physics and workflow from your earlier N-body work, rebuild that simulator in JAX, verify that it behaves sensibly, and then turn it into a reusable scientific instrument:

  1. rebuild the Project 2 simulator ideas into a trustworthy JAX-native Leapfrog simulator,
  2. generate trustworthy simulation outputs,
  3. train a fast emulator on those outputs,
  4. evaluate where the emulator is accurate and where it is fragile,
  5. use the emulator for an inverse problem.

This workflow shows up across modern computational science. Cosmologists emulate power spectra, climate scientists emulate circulation models, and engineers emulate expensive CFD pipelines. In this course, the N-body cluster problem gives you a concrete version of the same idea.

Part 0: Build The JAX-Native Leapfrog Simulator

Before you start coding, it is worth revisiting the paired course readings noted above.

Why This Part Must Be Explicit

Pedagogically, this project is much stronger if the simulator is not treated as a black box. The course has spent weeks building numerical-methods judgment, and the final project should force students to use that judgment in a modern JAX setting. If the simulator arrives pre-built, the project collapses into “train a neural network on someone else’s data pipeline,” which is a much thinner synthesis.

For that reason, the expected code path should include a JAX-native N-body simulator using Leapfrog integration that is rebuilt from Project 2’s validated physics and diagnostics. Students should absolutely reuse and adapt their own earlier scientific work rather than start from zero, but the final-project pipeline still needs to expose a genuinely JAX-native implementation, not a thin wrapper around the old one.

What “Student-Built” Means Here

At minimum, students should own:

  • the state representation for positions, velocities, and masses,
  • the force or acceleration computation,
  • the Leapfrog update step,
  • the simulation driver used to generate training data,
  • the validation checks that justify trusting the simulator.

This does not mean every project has to become a giant performance-optimization exercise or a blind rewrite from scratch. It does mean the scientific core of the simulator should still be theirs, should visibly descend from Project 2, and should live in the repo in readable JAX-native form.

Important: What counts as JAX-native evidence

For this course, “JAX-native” should mean more than “I replaced NumPy with jax.numpy.” A convincing JAX-native rebuild usually shows:

  • an array-first state representation that is natural for JAX,
  • a clean update or step function that passes state explicitly rather than hiding mutation inside object state,
  • at least one meaningful use of jit, vmap, or another JAX workflow tool where it actually helps the project,
  • a short explanation in the README or report of why this design is better for repeated simulation, data generation, or emulator training.

The goal is not to use every JAX feature. The goal is to show that you understand why JAX changes how you structure scientific code.
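As a concrete starting point, that array-first, explicit-state structure might look like the minimal sketch below. The function names, the `G = 1` units, and the softening default are illustrative choices, not requirements:

```python
import jax
import jax.numpy as jnp

def accelerations(pos, masses, softening=1e-2):
    # Pairwise softened gravitational accelerations, G = 1.
    dx = pos[None, :, :] - pos[:, None, :]        # dx[i, j] = r_j - r_i
    r2 = jnp.sum(dx**2, axis=-1) + softening**2   # softening keeps the diagonal finite
    inv_r3 = r2 ** (-1.5)
    return jnp.sum(masses[None, :, None] * dx * inv_r3[:, :, None], axis=1)

@jax.jit
def leapfrog_step(state, dt, softening=1e-2):
    # Kick-drift-kick update; state is passed explicitly, nothing is mutated.
    pos, vel, masses = state
    vel_half = vel + 0.5 * dt * accelerations(pos, masses, softening)
    pos_new = pos + dt * vel_half
    vel_new = vel_half + 0.5 * dt * accelerations(pos_new, masses, softening)
    return (pos_new, vel_new, masses)
```

Because the step is a pure function of `(state, dt)`, it composes cleanly with `jit` here and with `vmap` or `lax.scan` later when you need many simulations for data generation.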

What To Validate Before Data Generation

Before generating an emulator dataset, students should show evidence such as:

  • a small-\(N\) sanity check where trajectories behave qualitatively as expected,
  • bounded energy behavior rather than obvious secular blow-up,
  • consistent center-of-mass or momentum behavior when appropriate,
  • stable outputs under a timestep that they can defend.

The exact validation suite can vary, but it should be strong enough that later emulator failures are not secretly simulator bugs.
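One way to generate that evidence is a pair of small diagnostics like the sketch below, assuming a softened potential with `G = 1` to match the simulator itself:

```python
import jax.numpy as jnp

def total_energy(pos, vel, masses, softening=1e-2):
    # Kinetic plus softened pairwise potential energy, G = 1.
    n = pos.shape[0]
    kinetic = 0.5 * jnp.sum(masses * jnp.sum(vel**2, axis=-1))
    dx = pos[None, :, :] - pos[:, None, :]
    r = jnp.sqrt(jnp.sum(dx**2, axis=-1) + softening**2)
    mask = 1.0 - jnp.eye(n)                      # exclude self-interaction
    potential = -0.5 * jnp.sum(mask * masses[:, None] * masses[None, :] / r)
    return kinetic + potential

def relative_energy_drift(energies):
    # Max |E(t) - E(0)| / |E(0)| over a saved trajectory of energies.
    e = jnp.asarray(energies)
    return jnp.max(jnp.abs(e - e[0]) / jnp.abs(e[0]))
```

Plotting `relative_energy_drift` against timestep is also a direct way to defend your timestep choice: Leapfrog should show bounded oscillation, not secular growth.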

Part 1: Generate Training Data

Choose a Parameter Design That You Can Defend

You need a dataset that covers the input space well enough for the emulator to learn something real. A good starting point is a 2D design over \(Q_0\) and \(a\) with Latin Hypercube Sampling rather than naive uniform random draws. The main reason is coverage: with a modest sample budget, random draws can leave obvious holes, while a Latin hypercube gives you a more even spread across each axis.

A practical default is:

  • training set: about 80 to 100 simulations,
  • held-out test set: about 20 simulations,
  • an initial debug pass with a much smaller sample count before scaling up.

That small debug pass matters. It lets you verify the end-to-end workflow before you spend time generating the full dataset.
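A hand-rolled Latin hypercube is only a few lines if you want to avoid a library dependency; the sketch below uses the standard stratify-then-shuffle construction (`scipy.stats.qmc.LatinHypercube` is an equally valid choice):

```python
import numpy as np

def latin_hypercube(n_samples, bounds, rng):
    # bounds: list of (low, high) per dimension, e.g. [(0.5, 1.5), (0.1, 1.0)]
    # Each axis gets exactly one sample per stratum; shuffling decorrelates axes.
    dims = len(bounds)
    out = np.empty((n_samples, dims))
    for d, (lo, hi) in enumerate(bounds):
        strata = (np.arange(n_samples) + rng.random(n_samples)) / n_samples
        rng.shuffle(strata)
        out[:, d] = lo + strata * (hi - lo)
    return out
```

With 80 samples, each axis is guaranteed one point per 1/80-wide bin, which is exactly the coverage property that naive uniform draws lack.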

Hold the Right Things Fixed

For a course-scale project, it is reasonable to keep the following fixed while you vary \(Q_0\) and \(a\):

  • particle count \(N\),
  • IMF choice,
  • integrator, which should be Leapfrog for the baseline course lane,
  • timestep policy,
  • softening prescription.

The important thing is not which values you choose first, but whether you can justify them and keep them consistent across the training run. If you change multiple physical ingredients at once, it becomes much harder to interpret failure modes later.

Summary Statistics Should Be Physically Legible

A useful starting set of outputs is:

  • bound mass fraction \(f_{\rm bound}\),
  • velocity dispersion \(\sigma_v\),
  • half-mass radius \(r_h\).

Those are good teaching choices because each one maps back to a physically meaningful question:

  • How much of the system remains bound?
  • How energetic is the bound population?
  • How spatially concentrated is the cluster?

If you compute different diagnostics, explain why they are the right observables for your version of the problem.
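The three statistics above can be computed from a final snapshot along these lines. This is a sketch assuming `G = 1` and a softened per-particle potential; the bound-set definition used here (negative specific energy in the center-of-mass frame) is one reasonable convention, not the only one:

```python
import numpy as np

def summary_stats(pos, vel, masses, softening=1e-2):
    """Bound mass fraction, bound velocity dispersion, half-mass radius."""
    # Work in the center-of-mass velocity frame.
    v = vel - np.sum(masses[:, None] * vel, axis=0) / masses.sum()
    dx = pos[None, :, :] - pos[:, None, :]
    r = np.sqrt(np.sum(dx**2, axis=-1) + softening**2)
    np.fill_diagonal(r, np.inf)                      # no self-interaction
    phi = -np.sum(masses[None, :] / r, axis=1)       # per-particle potential, G = 1
    e = 0.5 * np.sum(v**2, axis=-1) + phi            # specific energy
    bound = e < 0.0
    if not bound.any():
        return 0.0, 0.0, 0.0
    f_bound = masses[bound].sum() / masses.sum()
    sigma_v = np.sqrt(np.mean(np.sum(v[bound]**2, axis=-1)))
    # Half-mass radius of the bound set, measured from its own center of mass.
    com = np.sum(masses[bound, None] * pos[bound], axis=0) / masses[bound].sum()
    radii = np.linalg.norm(pos[bound] - com, axis=1)
    order = np.argsort(radii)
    cum = np.cumsum(masses[bound][order])
    r_h = radii[order][np.searchsorted(cum, 0.5 * cum[-1])]
    return f_bound, sigma_v, r_h
```

Whatever convention you pick, state it in the report: the emulator learns your definition of these observables, not a platonic one.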

Validation Before Machine Learning

Before you train an emulator, make sure the data pipeline itself is believable. At minimum, confirm that:

  • the rebuilt JAX-native Leapfrog simulation runs reproducibly on a small sample,
  • the extracted summary statistics stay in physically plausible ranges,
  • the expected qualitative trends appear in simple diagnostic plots,
  • your saved data can be reloaded without ambiguity.

If the training data is wrong, the emulator will only help you produce wrong answers faster. This is why the simulator validation step is part of the pedagogy, not just part of the plumbing.

What Counts As A Good Project 2 To JAX Transition

A good transition does not mean copying the old simulator line by line into JAX syntax. It means carrying forward the right things:

  • the same physical equations,
  • the same validation cases,
  • the same diagnostic expectations,
  • the same reasoning about why Leapfrog is the correct long-run method.

What should change is the implementation style:

  • cleaner state representation,
  • JAX-native array operations,
  • JAX-friendly step functions,
  • a code path that can later support emulator training and fast repeated evaluation.

Part 2: Build the Emulator

Before you start training, review the paired JAX and ML readings from The Learnable Universe.

Why a Neural Network Is Reasonable Here

The recommended emulator is a small multilayer perceptron built in JAX with Equinox and optimized with Optax. For a low-dimensional regression problem like this, an MLP is a pragmatic choice:

  • flexible enough to learn nonlinear structure,
  • cheap to evaluate after training,
  • easy to integrate with the JAX ecosystem you have already used in the course.

Normalize Inputs and Outputs

Normalization is not optional polish here. If one feature has a much larger numerical scale than another, gradient-based training becomes harder to interpret and tune. Compute normalization statistics on the training set, reuse them everywhere else, and make the inverse transformation part of your documented prediction path.
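A minimal sketch of that discipline keeps the training-set statistics on one object so the same transform and its inverse are reused everywhere (class and method names here are illustrative):

```python
import numpy as np

class Normalizer:
    """Standardize features; fit on training data only, reuse everywhere else."""
    def fit(self, x):
        self.mean = x.mean(axis=0)
        self.std = x.std(axis=0) + 1e-8   # guard against constant columns
        return self

    def transform(self, x):
        return (x - self.mean) / self.std

    def inverse(self, z):
        # The documented prediction path: model output -> physical units.
        return z * self.std + self.mean
```

Fit one `Normalizer` for inputs and one for outputs, save both alongside the model weights, and never refit on test data.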

Start Simple

A good baseline architecture is small:

  • input dimension 2,
  • two hidden layers,
  • a moderate hidden width,
  • output dimension matching your summary statistics.

If that baseline already works, you have learned something useful. There is no prize for making the model larger before the data pipeline and evaluation logic are solid.

Beat One Simple Baseline First

Before you trust a neural emulator, compare it against at least one simple baseline. Good options include:

  • predicting the training-set mean,
  • a linear regression model,
  • a nearest-neighbor or interpolation-style baseline.

This comparison matters because it answers a real scientific question: is the neural network learning meaningful structure, or is it only adding complexity? A baseline that is easy to beat is still useful, because it gives your held-out metrics context.
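Two of those baselines fit in a few lines of NumPy (the function names are illustrative):

```python
import numpy as np

def mean_baseline(y_train):
    # Predict the training-set mean for every input.
    mu = y_train.mean(axis=0)
    return lambda X: np.tile(mu, (len(X), 1))

def linear_baseline(X_train, y_train):
    # Ordinary least squares with a bias column.
    A = np.hstack([X_train, np.ones((len(X_train), 1))])
    W, *_ = np.linalg.lstsq(A, y_train, rcond=None)
    return lambda X: np.hstack([X, np.ones((len(X), 1))]) @ W
```

If the neural emulator cannot clearly beat the linear fit on held-out data, that is itself a finding worth reporting: it suggests the response surface is nearly linear over your design range.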

Part 3: Evaluate the Emulator

Held-Out Accuracy Is the First Gate

Compute straightforward regression metrics on the held-out set, such as MAE and RMSE for each output. Then inspect predicted-versus-true plots. The numerical metrics tell you the scale of the error; the plots tell you whether the model is biased or systematically failing in a particular region.

For this project, that held-out evaluation should include your chosen simple baseline as well as the neural emulator. The exact baseline is up to you, but the comparison should make it easy to see what the neural model is buying you.
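A per-output report along those lines might look like this sketch:

```python
import numpy as np

def regression_report(y_true, y_pred, names):
    # Per-output MAE and RMSE on the held-out set.
    report = {}
    for j, name in enumerate(names):
        err = y_pred[:, j] - y_true[:, j]
        report[name] = {"MAE": float(np.mean(np.abs(err))),
                        "RMSE": float(np.sqrt(np.mean(err**2)))}
    return report
```

Run it once for the neural emulator and once for each baseline so the comparison lands in a single table; report errors in physical units, after inverting the output normalization.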

Check Behavior Near Edges

Interpolation is much easier than extrapolation. If your design space spans \(Q_0 \in [0.5, 1.5]\) and a chosen range of \(a\), you should explicitly ask whether the model is less reliable near the edges than near the center. Even a one-dimensional slice through parameter space can reveal whether the emulator remains smooth and sane where training coverage thins out.

Add Some Uncertainty Story

A simple and defensible starting point is an ensemble of independently initialized networks. The spread across the ensemble is not a complete uncertainty model, but it is often enough to expose obviously unstable regions and to keep you from treating a single deterministic predictor as omniscient.
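A sketch of that ensemble spread, treating each trained model as a plain prediction function:

```python
import numpy as np

def ensemble_predict(models, X):
    # models: independently initialized and trained predictors, each mapping
    # an (n_points, n_inputs) array to (n_points, n_outputs) predictions.
    preds = np.stack([m(X) for m in models])       # (n_models, n_points, n_outputs)
    return preds.mean(axis=0), preds.std(axis=0)   # spread as a rough uncertainty proxy
```

Regions where the ensemble standard deviation spikes, typically near the design-space edges, are exactly the regions to flag before inference.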

Part 4: Use the Emulator for Inference

The Inverse Problem

Once the emulator is fast enough, you can use it as the forward model inside a probabilistic program. The recommended lane is NumPyro with NUTS:

  • sample priors on the physical parameters,
  • call the emulator to generate predicted observables,
  • compare predictions to a held-out synthetic observation through a likelihood model,
  • inspect the posterior and ask whether the true parameters are recovered.

Do Not Skip the Likelihood Story

Students often jump straight from emulator predictions to a sampler without thinking carefully about the uncertainty model. Your likelihood width should reflect something you can defend, such as emulator error on the held-out set or a similarly justified uncertainty budget. If the likelihood is unrealistically tight, your posterior becomes falsely confident. If it is unrealistically broad, the posterior becomes uninformative.

Recovery on Known Cases

The cleanest validation is to choose a held-out simulation with known true parameters, treat its summary statistics as synthetic observations, and ask whether your posterior places meaningful mass near the truth. A corner plot with the true values marked is usually the clearest presentation.

Package and Repo Structure

The exact layout is up to you, but a strong final-project repo usually separates concerns clearly:

  • source package or modules,
  • scripts or commands for data generation, training, and inference,
  • tests,
  • outputs and figures,
  • a top-level README.md with reproduction instructions.

What matters most is that another person can identify the simulator, emulator, and inference entrypoints quickly and reproduce the core results without guessing.

Figures That Usually Belong in the Report

The following figure set is a strong default:

  1. training-data coverage in parameter space,
  2. training or validation loss curves,
  3. predicted-versus-true plots for the held-out set,
  4. a slice plot showing mean prediction and uncertainty behavior,
  5. a posterior or recovery plot for the inference stage.

Use captions to explain what scientific claim each figure supports. A figure should not be present only because it is standard.

Common Failure Modes

The Simulation, the Data Pipeline, and the Inference All Break at Once

This usually means you moved to inference before validating the upstream layers. Work in order:

  1. Leapfrog simulator sanity,
  2. summary-statistic sanity,
  3. emulator accuracy,
  4. inference.

The Neural Network Will Not Train

Check the basics first:

  • normalization,
  • NaN in the data,
  • inconsistent train/test transforms,
  • learning rate that is too large,
  • output targets with very different scales.

The Posterior Misses the Truth

That can mean several different things:

  • the emulator is inaccurate in that region,
  • the likelihood width is too small,
  • the summary statistics are not sufficiently informative,
  • the simulator and the inference target are not actually matched.

The fix is almost never "run the sampler longer." Work out which of those failure modes you are actually in before touching sampler settings.