graph LR
    A["Your Function<br/>f(x)"] --> B["grad f<br/>Differentiate"]
    A --> C["jit f<br/>Compile"]
    A --> D["vmap f<br/>Vectorize"]
    A --> E["pmap f<br/>Parallelize"]
    B --> F["Compose!<br/>jit(grad(f))"]
    C --> F
    D --> F
    E --> F
    F --> G["Fast, batched,<br/>differentiable,<br/>parallel code"]
Overview: Computing the Universe with JAX
The Learnable Universe | Module 1 | COMP 536
The Crisis That Changed Everything
Mountain View, California. 2017-2018.
Researchers at Google Brain faced a familiar research bottleneck. They wanted the ease of NumPy for prototyping ML experiments, but needed automatic differentiation and hardware acceleration on Google’s custom TPUs. Existing frameworks forced an uncomfortable choice.
TensorFlow required pre-defining static computational graphs. Experimentation meant rebuilding graphs repeatedly. PyTorch offered flexibility but didn’t efficiently target specialized hardware like TPUs. And manually distributing large computations across multiple accelerators? A complex engineering nightmare for each new algorithm.
The team’s insight: What if, instead of one monolithic framework, they built a library of composable function transformations?
Write ordinary Python and NumPy code. Then apply transformations as you need them:
- grad for automatic differentiation
- jit for Just-In-Time compilation to XLA (Accelerated Linear Algebra)
- vmap for automatic vectorization
- pmap for parallelization across devices

Compose them like Lego blocks: jit(vmap(grad(f))) just works.

They presented this approach at SysML 2018 and released it as JAX later that year. The key innovation: composable transformations that separated concerns — automatic differentiation (grad), just-in-time compilation (jit), and parallelization (pmap) became independent operations you could mix and match freely.

What Google built for ML research would accidentally revolutionize scientific computing. The same tools that enabled large-scale neural network training would solve one of computational physics' deepest bottlenecks: the astronomical cost of gradients.
In Project 4, you spent hours implementing Hamiltonian Monte Carlo. You hand-coded finite difference gradients, watching your computer evaluate the log-posterior thousands of times just to approximate derivatives. It worked. But deep down, you knew: There has to be a better way.
In Project 2, you wrote an N-body simulator. It took minutes to run 100 particles for 100 timesteps. You accepted this as “the cost of doing physics.”
You stand at a crossroads.
Down one path: continue with NumPy, finite differences, CPU-bound loops. The well-trodden path of academic computational physics. It works. Down the other path: automatic differentiation (“autodiff”), just-in-time (JIT) compilation, GPU acceleration. The path that enables previously impractical science. Physics-informed machine learning. Instant parameter inference from surrogate models. Real-time exploration of billion-dimensional parameter spaces.
This module is that crossroads.
By the time you finish this module and the JAX-native rebuild phase of the final project, your N-body code should run much faster than your original Project 2 version, support numerically exact gradients where appropriate, and generate simulation ensembles large enough to train and evaluate surrogate models.
Learning Outcomes
By the end of this overview, you will be able to:
Learning Paths & Priorities
This overview supports three learning paths. Choose based on your background and time constraints:
For students with ML experience or time pressure
Focus on sections marked 🔴 Essential:
- What This Module Is About
- The Pain Point
- A Taste of JAX
- Looking Ahead
Outcome: Understand JAX’s core value proposition and what you’ll build. Proceed directly to Part 1.
For most students — balances context with efficiency
Add sections marked 🟡 Important:
- All Essential sections above
- The Computational Landscape
- Google’s JAX Origin Story
- The Three Computational Eras
Outcome: Full context for WHY JAX matters, HOW it came to exist, and WHERE it fits in scientific computing history.
For students interested in career context or research trajectory
Include all sections:
- All Essential and Important sections
- Career Context (dropdown admonition)
- All conceptual checkpoints and reflection prompts
Outcome: Deep understanding of JAX’s role in the academic-industry landscape, career implications, and full historical context.
This is NOT Part 1. You don’t need to code anything here. This is pure motivation and context.
Goals:
- Connect to your lived experience (Project 4 pain, Project 2 performance limits)
- Show you what’s possible (automatic differentiation, substantial speedups, batched simulation ensembles)
- Explain why JAX exists and why you need to learn it now
- Preview the transformation you’ll undergo (script writer \(\to\) software engineer)
After this overview: Proceed to Part 1 for conceptual foundations (functional programming, pure functions, computational graphs).
What This Module Is About
Priority: 🔴 Essential
This module transforms you from script writers to scientific software engineers. You’ll learn to:
- Compute gradients automatically and exactly (no more finite differences!)
- Speed up code by 10-100\(\times\) through JIT compilation
- Batch thousands of computations effortlessly via vectorization
- Build professional, reusable scientific software (not just scripts that “work”)
- Generate reproducible training data for machine learning and emulation
Automatic differentiation: Computing numerically exact derivatives (machine precision) via computational graphs and the chain rule, not finite-difference approximations.
JIT compilation: Just-In-Time compilation to optimized machine code via Google’s XLA compiler.
Vectorization: Automatic batching of operations across data, enabling massive parallelism.
This isn’t just learning a new library. It’s learning how modern scientific software is built.
By the end of Module 1 and the early final-project JAX rebuild, you’ll have transformed your N-body simulator from a NumPy script into a modern, reproducible JAX workflow that:
- Runs substantially faster than the original NumPy version
- Computes forces via automatic differentiation or carefully validated analytical expressions
- Batches simulation ensembles for downstream emulation work
- Is tested, documented, and reproducible enough to support a public-facing final project
Where you’ve been:
- Projects 1-3: Forward simulation (Monte Carlo stellar populations, N-body dynamics, radiative transfer)
- Module 5 & Project 4: Inverse problems (Bayesian inference, HMC for cosmological parameters)
Where you’re going:
- Module 1 + the final-project JAX rebuild: Learn JAX, rebuild your Project 2 simulator, validate it, and generate training data
- Module 2: Machine Learning: Train surrogate models on simulation outputs
- Final Project: Physics-informed learning (combine simulation + ML for instant inference)
Module 1 is the inflection point — where you cross from Era 2 (expensive inference) to Era 3 (amortized inference via learned surrogates).
Google’s JAX — Why Does This Library Exist?
Priority: 🟡 Important
The Problem Google Faced (2015-2017)
Google had built custom AI accelerators called TPUs (Tensor Processing Units) for machine learning. These chips could run computations 10-100\(\times\) faster than GPUs — but only if software could exploit them properly.
The problem:
- TensorFlow (Google’s first ML framework): Required static computational graphs. You couldn’t write flexible, exploratory Python code — you had to pre-define everything before running.
- PyTorch (Facebook’s framework): Flexible (dynamic graphs) but didn’t compile efficiently to specialized hardware.
- CUDA (NVIDIA’s low-level GPU programming): Full control but required hardware expertise most researchers didn’t have.
What researchers actually wanted: Write ordinary Python code describing physics or ML models, and automatically get:
- Gradients (for optimization/inference)
- Compilation (for speed)
- Hardware acceleration (TPU/GPU support)
- Parallelization (across multiple devices)
Without needing to be hardware experts or ML engineers.
The Innovation: JAX’s Key Insight
JAX’s breakthrough was complete separation of concerns:
You write: Pure Python functions (just describe your computation — your physics)
JAX provides: Independent transformations you compose:
The magic: These transformations compose freely:
# All of these are valid!
jit(grad(f))              # Fast gradients
vmap(grad(f))             # Batched gradients
jit(vmap(grad(f)))        # Fast batched gradients
pmap(jit(vmap(grad(f))))  # Multi-GPU fast batched gradients

Composability: The property that transformations can be combined in any order, with the result being predictable and well-defined.
Separation of concerns: Design principle where different aspects (differentiation, compilation, parallelization) are independent and don't interfere with each other.
Why this was novel: Previous frameworks entangled differentiation with execution. You couldn’t easily say “I want gradients of this function, but also compile it and run it on GPU.”
JAX decouples everything — you describe transformations independently, compose them like Lego blocks.
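To make the Lego-block claim concrete, here is a minimal sketch (a toy function, not course code) applying all three transformations to the same f and composing them:

```python
import jax
import jax.numpy as jnp

def f(x):
    """Toy scalar function: f(x) = sum(x_i^2), so grad f(x) = 2x."""
    return jnp.sum(x**2)

df = jax.grad(f)                    # differentiate
fast_df = jax.jit(jax.grad(f))      # differentiate, then compile
batched_df = jax.vmap(jax.grad(f))  # differentiate, then batch

x = jnp.array([1.0, 2.0, 3.0])
print(df(x))                        # [2. 4. 6.]
print(fast_df(x))                   # same values, now compiled
xs = jnp.stack([x, 2 * x])          # a batch of two inputs, shape (2, 3)
print(batched_df(xs).shape)         # (2, 3): one gradient per batch element
```

Each transformation returns an ordinary Python function, which is exactly why they nest in any order.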
Why It Matters for You (Not Just Google)
Google didn’t build JAX for speed alone — they built it because ambitious research questions required automatic differentiation through arbitrary custom code:
- Physics simulators (like your N-body code)
- Domain-specific algorithms (stellar evolution, radiative transfer, hydrodynamics)
- Novel ML architectures (combining data-driven models + physics constraints)
The key insight: If you can differentiate through ANY computation, new science becomes possible.
In Project 2, you hand-coded force calculations:
# You manually derived: F = -G * m1 * m2 / r^2 * r_hat
for i in range(n):
    for j in range(i+1, n):
        r_vec = positions[j] - positions[i]
        r = np.linalg.norm(r_vec)
        force_mag = G * masses[i] * masses[j] / r**2
        force_vec = force_mag * r_vec / r  # Unit vector
        forces[i] += force_vec
        forces[j] -= force_vec  # Newton's 3rd law

With JAX autodiff, you can instead:
# Define potential energy (easier to get right than vector forces)
def U(positions, masses):
    """Total potential energy: U = -sum over pairs i<j of G m_i m_j / r_ij (a scalar)."""
    i, j = jnp.triu_indices(len(masses), k=1)  # all pairs i < j
    r_ij = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    return -jnp.sum(G * masses[i] * masses[j] / r_ij)

# Get forces automatically: F = -∇U
forces = -jax.grad(U, argnums=0)(positions, masses)

Why this matters:
- Fewer bugs (potential energy is scalar, simpler than vector forces)
- Generalizes easily (add magnetic fields, relativity, exotic potentials)
- Numerically exact derivatives (machine precision, not hand-coded approximations that might have errors)
This is the power JAX unlocks.
Current Reality (2025)
Frontier research labs (DeepMind, Anthropic, Google Research, national labs) now use JAX as production infrastructure.
You’re learning tools that enable cutting-edge research, not pedagogical simplifications.
JAX wasn’t built for astrophysics — but it solves our problems perfectly.
Key insights about JAX’s origin:
- JAX emerged from hardware constraints (TPUs needed better software)
- Composable transformations are the innovation (not just “another array library”)
- Separation of concerns enables flexibility (grad, jit, vmap work independently and compose)
- Autodiff through arbitrary code unlocks new science (not just faster old science)
- Production infrastructure, not toy (DeepMind, Anthropic use this for frontier research)
Why this matters: Understanding JAX’s design philosophy helps you use it effectively — you’ll know WHY functional programming is required, WHY transformations compose, WHY it’s worth the learning curve.
The rare combination you’re building:
Most astronomy and physics master's and PhD students graduate with either (1) domain expertise (astrophysics, statistical mechanics) or (2) modern ML infrastructure (PyTorch, production systems). You're learning both — plus professional software engineering (testing, packaging, documentation).
Where this combination is explicitly hiring (2025):
- Research labs: DeepMind, Anthropic, OpenAI (physics-informed ML, differentiable simulators)
- Scientific ML: Climate modeling, materials science, drug discovery (simulation + learning)
- Computational astrophysics: Research-intensive universities (modern methods attract top journals)
What makes you different: By the Final Project, you’ll have a portfolio-quality scientific software workflow — a JAX-native N-body simulator, careful validation, reproducible commands, and a simulation-to-emulator pipeline. Most graduates won’t build something this integrated until much later, if ever.
The signal: “I built differentiable simulators that generate training data 100\(\times\) faster” opens doors that “I trained models on existing data” doesn’t. You’re learning to create the data that powers modern science, not just analyze it.
JAX specifically: Growing rapidly in research (not production ML). Complements PyTorch (which you can learn easily once you know JAX). Positions you at the intersection of physics + ML where the most interesting problems live.
The Pain Point — Your Project 4 Experience
Priority: 🔴 Essential
Remember Project 4?
In Project 4, you implemented Hamiltonian Monte Carlo to measure dark energy parameters. For each HMC step, you needed gradients of the log-posterior, which you computed using finite differences — approximating derivatives by evaluating the function at slightly different parameter values.
This worked. You got results. Dark energy parameters estimated. Project submitted.
But the pain was real:
- Scaling nightmare: Each gradient required 2d function evaluations (d = number of parameters)
- d=2 (your Project 4): 200,000 likelihood evaluations for one HMC run — tolerable
- d=100 (realistic astronomy problems): 10 million evaluations — completely impractical
- Numerical fragility: Tuning the step size \(h\) was finicky
- Too large: inaccurate derivatives (\(\mathcal{O}(h^2)\) error)
- Too small: floating-point precision errors
- Optimal \(h\) varies by parameter — manual tuning required
- Approximate, not exact: Gradient errors compound through thousands of HMC steps, causing poor mixing and biased estimates
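The step-size dilemma is easy to see on a toy function whose true derivative is known. A minimal sketch (using sin as a stand-in, not your actual log-posterior):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)  # stand-in for a log-posterior; exact derivative is cos(x)

x0 = 1.0
exact = jnp.cos(x0)

# Central differences: accuracy depends delicately on h
for h in (1e-1, 1e-5, 1e-12):
    fd = (f(x0 + h) - f(x0 - h)) / (2 * h)
    print(f"h = {h:.0e}: error = {abs(fd - exact):.1e}")

# Autodiff: no h to tune; exact up to floating-point precision
print(f"autodiff error = {abs(jax.grad(f)(x0) - exact):.1e}")
```

With JAX's default float32, large h suffers truncation error, tiny h is destroyed by round-off, and only a narrow middle range works, while jax.grad has no knob to tune at all.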
What You’ll Be Able to Do After Module 1
Project 4 (Finite Differences):
# Many lines of careful numerical code
grad = np.zeros_like(theta)
for i in range(len(theta)):
    theta_plus, theta_minus = theta.copy(), theta.copy()
    theta_plus[i] += h
    theta_minus[i] -= h
    grad[i] = (log_posterior(theta_plus) - log_posterior(theta_minus)) / (2*h)
# Cost: 2d likelihood evaluations (approximate)

After Module 1 (Automatic Differentiation):

grad_log_posterior = jax.grad(log_posterior)
# Cost: ~3-5× one forward pass, independent of d (and exact!)

One line. Not just shorter — exact (to machine precision), fast (roughly 400\(\times\) faster for d = 1000), and automatic (works for any function JAX can trace).
HMC requires ~50,000 gradient evaluations per run (50 leapfrog steps \(\times\) 1000 proposals). With finite differences, this becomes computationally impossible for d > 10. With autodiff, high-dimensional inference becomes feasible. This is why modern tools (PyMC, Stan, NumPyro) all use autodiff.
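To see the cost claim concretely, here is a toy example (a standard Gaussian log-posterior, not your Project 4 one): a single reverse-mode sweep returns all d gradient components at once, whereas central differences would need 2d separate evaluations.

```python
import jax
import jax.numpy as jnp

def log_posterior(theta):
    """Toy log-posterior: standard Gaussian, so the gradient is -theta."""
    return -0.5 * jnp.sum(theta**2)

d = 1000
theta = jnp.linspace(-1.0, 1.0, d)

g = jax.grad(log_posterior)(theta)    # one call, all 1000 components
print(bool(jnp.allclose(g, -theta)))  # True: matches the analytic gradient
# Central differences would need 2 * d = 2000 calls to log_posterior instead.
```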
A Taste of JAX — Three Transformations That Change Everything
Priority: 🔴 Essential
This is a brief preview of JAX’s core capabilities. Don’t worry about understanding every detail — Parts 1-2 will explain the how and why. For now, see what’s possible.
1. Automatic Differentiation: jax.grad
Problem: In Project 2, you hand-coded gravitational forces with nested loops, manual vector math, and careful Newton’s 3rd law bookkeeping.
JAX solution: Write the simpler function (potential energy), get forces for free:
import jax
import jax.numpy as jnp

G = 6.67e-8  # Gravitational constant [CGS]

def gravitational_potential(positions, masses):
    """Compute potential energy U (scalar, simple)."""
    n = len(masses)
    eps = 0.01  # Plummer softening length
    U = 0.0
    for i in range(n):
        for j in range(i+1, n):
            r2 = jnp.sum((positions[j] - positions[i])**2)
            r = jnp.sqrt(r2 + eps**2)  # Softened separation
            U -= G * masses[i] * masses[j] / r
    return U

# Automatic differentiation: F = -∇U (note the minus sign)
def force_function(positions, masses):
    return -jax.grad(gravitational_potential, argnums=0)(positions, masses)

What you get: Exact forces via chain rule, automatically. Want to add magnetic fields or modified gravity? Just update the potential and call jax.grad — forces come free.
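A sanity check worth trying (a vectorized re-statement of the potential, written out here so the snippet stands alone): because U depends only on particle separations, the autodiff forces conserve total momentum, summing to zero with no explicit Newton's-3rd-law bookkeeping.

```python
import jax
import jax.numpy as jnp

G = 6.67e-8  # Gravitational constant [CGS]

def potential(positions, masses, eps=0.01):
    """Softened pairwise potential energy, vectorized over all i < j pairs."""
    i, j = jnp.triu_indices(len(masses), k=1)
    r = jnp.sqrt(jnp.sum((positions[i] - positions[j])**2, axis=-1) + eps**2)
    return -jnp.sum(G * masses[i] * masses[j] / r)

positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
masses = jnp.array([1.0, 2.0, 3.0])

forces = -jax.grad(potential)(positions, masses)  # F = -∇U
print(float(jnp.abs(forces.sum(axis=0)).max()))   # ~0: momentum conserved for free
```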
Connection to Project 4: This is the same jax.grad you’ll use for HMC gradients — exact, fast, automatic.
2. JIT Compilation: jax.jit — 10-100\(\times\) Speedups
Problem: NumPy loops are slow (Python interpreter overhead). Your Project 2 N-body code took minutes for 100 particles.
JAX solution: Compile Python to optimized machine code:
import numpy as np

G = 6.67e-8  # Gravitational constant [CGS]

def compute_forces_numpy(positions, masses):
    """N-body force calculation (NumPy version)."""
    n = len(masses)
    forces = np.zeros_like(positions)
    for i in range(n):
        for j in range(n):
            if i != j:
                r_vec = positions[j] - positions[i]
                r = np.linalg.norm(r_vec) + 1e-10
                forces[i] += G * masses[i] * masses[j] * r_vec / r**3
    return forces

def compute_forces_jax(positions, masses):
    """N-body forces (JAX version — same logic, but jnp arrays are immutable,
    so in-place updates become functional .at[i].add(...) updates)."""
    n = len(masses)
    forces = jnp.zeros_like(positions)
    for i in range(n):
        for j in range(n):
            if i != j:
                r_vec = positions[j] - positions[i]
                r = jnp.linalg.norm(r_vec) + 1e-10
                forces = forces.at[i].add(G * masses[i] * masses[j] * r_vec / r**3)
    return forces

# JIT-compile the JAX version (one line!)
compute_forces_jit = jax.jit(compute_forces_jax)

Now benchmark:
import time

# Setup: 100 particles
n_particles = 100
positions_np = np.random.rand(n_particles, 3) * 1e13  # Random positions [cm]
masses_np = np.random.rand(n_particles) * 1e33        # Random masses [g]

# NumPy timing
start = time.time()
for _ in range(100):
    forces_np = compute_forces_numpy(positions_np, masses_np)
time_numpy = (time.time() - start) / 100

# JAX timing (convert arrays)
positions_jax = jnp.array(positions_np)
masses_jax = jnp.array(masses_np)

# First call: compilation overhead (ignore this one)
_ = compute_forces_jit(positions_jax, masses_jax).block_until_ready()

# Actual timing
start = time.time()
for _ in range(100):
    forces_jax = compute_forces_jit(positions_jax, masses_jax).block_until_ready()
time_jax_jit = (time.time() - start) / 100

print(f"NumPy time: {time_numpy*1000:.2f} ms")
print(f"JAX+JIT time: {time_jax_jit*1000:.2f} ms")
print(f"Speedup: {time_numpy / time_jax_jit:.1f}×")

.block_until_ready(): JAX dispatches work asynchronously; this call forces the computation to complete before timing. Without it, you'd measure the time to launch the computation, not the time to finish it.
Typical output (your mileage may vary):
NumPy time: 45.2 ms
JAX+JIT time: 1.2 ms
Speedup: 37.7×
What happened?
JAX compiled your Python function to optimized machine code (via XLA compiler). The XLA compiler:
- Fuses operations (one kernel instead of many Python loops)
- Eliminates intermediate arrays (no memory allocation overhead)
- Optimizes for your specific CPU/GPU (uses SIMD, caching, etc.)
One transformation (jax.jit, applied as a function call or as the @jax.jit decorator) gave you ~40\(\times\) speedup. This is typical for numerical code.
Actual speedups depend on:
- Hardware: CPU (10-30\(\times\)), GPU (30-100\(\times\)), TPU (50-200\(\times\))
- Problem size: Bigger problems \(\to\) better speedups (more to optimize)
- Code structure: More operations to fuse \(\to\) better optimization
Typical range: 10-100\(\times\) for scientific computing workloads.
See Part 6, Appendix A for detailed benchmarks with hardware specs, problem sizes, and breakdown of where speedups come from.
Misconception: “JIT just makes Python loops faster.”
Reality: JIT compiles your function to machine code (like C/Fortran). The speedup comes from:
- Operation fusion (eliminate Python interpreter overhead)
- Memory optimization (eliminate intermediate arrays)
- Hardware-specific code (use CPU vector instructions, GPU parallelism)
Implication: You still need to write “JAX-friendly” code (pure functions, no data-dependent control flow). Part 1 teaches you what “JAX-friendly” means.
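As a brief preview of what "JAX-friendly" means (Part 1 covers this properly): under jit, a Python if on a traced value fails, because tracing sees an abstract value with no concrete truth value; the JAX-friendly move is to express the branch as data. A sketch with a hypothetical softened-inverse helper:

```python
import jax
import jax.numpy as jnp

# NOT JAX-friendly: `if r < eps:` on a traced value raises an error under jit,
# because at trace time r is an abstract tracer, not a concrete number.

# JAX-friendly: express the branch as data with jnp.where
@jax.jit
def softened_inverse(r, eps=1e-3):
    safe_r = jnp.maximum(r, eps)            # keep both branches finite
    return jnp.where(r < eps, 0.0, 1.0 / safe_r)

print(float(softened_inverse(2.0)))   # 0.5
print(float(softened_inverse(1e-6)))  # 0.0
```

Note that jnp.where evaluates both branches, which is why safe_r guards the division instead of relying on the branch to skip it.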
3. Vectorization: jax.vmap — Batch 1000 Simulations Automatically
Problem: You need a large ensemble of N-body simulations for emulator training and evaluation. A sequential NumPy loop quickly becomes a bottleneck.
JAX solution: Automatic vectorization over batch dimension:
def simulate(initial_conditions):
    # ... your N-body integration ...
    return final_state

# Automatic batching (no manual loops!)
batched_simulate = jax.vmap(simulate)
results = batched_simulate(initial_conditions_batch)  # e.g., 1000 sims in parallel

Impact for the final project: batched JAX workflows can shrink data-generation time from "too slow to iterate on" to "fast enough to support emulator training within the course timeline." The exact gain depends on hardware, problem size, and how well you structure the simulation.
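A runnable miniature of the same idea, with a toy 1-D harmonic oscillator standing in for your N-body integrator (names and physics here are illustrative only):

```python
import jax
import jax.numpy as jnp

def simulate(state, dt=0.01, n_steps=1000):
    """Toy 'simulation': integrate x'' = -x with symplectic Euler steps."""
    def step(carry, _):
        x, v = carry
        v = v - dt * x
        return (x + dt * v, v), None
    (x, v), _ = jax.lax.scan(step, (state[0], state[1]), None, length=n_steps)
    return jnp.array([x, v])

batched_simulate = jax.vmap(simulate)  # one line: batched simulator

ics = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]])  # three initial conditions
finals = batched_simulate(ics)
print(finals.shape)  # (3, 2): final (x, v) for each simulation
```

vmap adds a leading batch axis to the inputs and outputs; the body of simulate never mentions the batch at all.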
The Key: Composable Transformations
JAX’s superpower is that these three transformations compose freely:
# Combine them: fast, batched, automatic gradients
fast_batched_gradients = jax.jit(jax.vmap(jax.grad(loss_fn)))

This enables new science — not just faster old science. Workflows that used to be too expensive for a course project become realistic enough to explore, validate, and iterate on.
Parts 1-2 explain the how and why. This was just a taste to build excitement.
The Three Computational Eras
Priority: 🟡 Important
Scientific computing evolved through three complementary approaches. All three coexist today — you’ll use each depending on the problem:
| Era | Question | Example | Status |
|---|---|---|---|
| Era 1: Forward Simulation (1950s+) | “What happens if…?” | Run N-body with these ICs | Still essential foundation |
| Era 2: Inverse Problems (1990s+) | “What parameters fit…?” | MCMC for cosmological params | Standard for inference |
| Era 3: Learning from Sims (2020s+) | “Amortize inference” | Train surrogate, instant predictions | Emerging frontier |
Module 1 is your bridge from Era 2 to Era 3.
Era 1: Forward Simulation
What: Physics equations \(\to\) simulate forward \(\to\) observe results
Your course examples: Projects 1-3 (stellar populations, N-body, radiative transfer)
Limitation: Can’t efficiently answer “what parameters produced this observation?” (requires expensive trial-and-error or grid search)
Era 2: Inverse Problems (Bayesian Inference)
What: Have observations \(\to\) infer parameters via probability (Bayes’ theorem)
Your course example: Project 4 (HMC for dark energy parameters from supernova data)
Key tools: MCMC, HMC, nested sampling — all require many gradient evaluations (this is where JAX’s autodiff saves you)
Limitation: Expensive! Each posterior sample requires running the forward model. For complex simulations (hours each), inference takes weeks.
Era 3: Amortized Inference via ML Surrogates
What: Train ML models on thousands of simulations \(\to\) make instant predictions
The breakthrough: Pay upfront cost once (generate training data with JAX’s speed), then inference is cheap forever
Your course arc: 1. Module 1 + the final-project JAX rebuild: Rebuild your Project 2 simulator in JAX and generate training data 2. Module 2: Train a surrogate model to learn: parameters \(\to\) observables 3. Final Project: Use the surrogate for fast Bayesian inference
Why JAX enables this: It lowers the cost of repeated simulation enough that training, evaluating, and stress-testing an emulator becomes feasible in one semester.
Course Arc Through the Eras
- Projects 1-3: Era 1 (forward simulation — build physics intuition)
- Project 4, Module 5: Era 2 (inverse problems — learn Bayesian inference) \(\to\) Pain point: Finite differences for gradients don’t scale
- Module 1 (JAX): Bridge to Era 3 (tools for fast simulation + gradient computation)
- Module 2 (Machine Learning): Era 3 (train surrogates on simulation ensembles)
- Final Project: Combine all three (physics-informed learning)
Eras 1-3 are not chronological replacements — they’re complementary approaches:
- Era 3 requires Era 1: Generate thousands of training simulations (JAX-accelerated forward modeling)
- Era 3 enhances Era 2: Surrogates become fast likelihoods for MCMC/HMC
- Era 2 validates Era 3: Use full simulations on parameter subsets to verify surrogate accuracy
Your final-project workflow: Generate a validated simulation ensemble (Era 1) \(\to\) train a surrogate (Era 3) \(\to\) use it in Bayesian inference (Era 2) \(\to\) validate against held-out simulations (Era 1)
Looking Ahead — What You’ll Build
Priority: 🔴 Essential
Module 1 Arc
Part 1: Conceptual Foundations (🔴) \(\to\) Functional programming, pure functions, computational graphs
Part 2: Core Transformations (🔴) \(\to\) grad, jit, vmap (mathematics + practice)
Part 3: N-body Migration (🟡) \(\to\) Rewrite Project 2 in JAX with autodiff forces
Part 4: JAX Ecosystem (🟢) \(\to\) Optax, Equinox, Diffrax, NumPyro (enrichment)
Part 5: Professional Software Engineering (🟡) \(\to\) Packaging, testing, documentation
Part 6: Synthesis (🟡) \(\to\) Reflection on transformation (script writer \(\to\) software engineer)
Final Project Preview: A Professional JAX N-body Workflow
By the end of the early final-project build, you’ll have a JAX-native N-body workflow that:
Runs much faster than your original Project 2 version, especially once the hot loops are compiled and batched
Computes forces via autodiff or validated analytical expressions:
def U(positions, masses):  # Potential energy (simple!)
    ...  # sum of -G * m_i * m_j / r_ij over all pairs (see Part 3)

forces = -jax.grad(U)(positions, masses)  # Automatic, exact

Batches simulation ensembles for emulator training and evaluation
Generates ML training data for the emulator work in Module 2 and the final project
Is professionally organized (tests, docs, reproducible commands, and a report-ready workflow)
The Transformation
Before Module 1: Script writer (finite differences, slow loops, one-off NumPy code)
After Module 1: Scientific software engineer (autodiff, vectorization, tested workflows)
Paradigm shift:
Approximate methods → Exact transformations
One-off scripts → Reusable software
Black-box tools → Glass-box understanding
This is the bridge to Era 3: Your JAX workflow enables Module 2’s surrogate-modeling work and the final project’s simulation-to-emulation pipeline.
Ready to Begin?
Self-check before Part 1:
If yes \(\to\) Continue to Part 1: Conceptual Foundations
Part 1 teaches WHY (functional programming, pure functions, computational graphs). Part 2 teaches HOW (transformation mechanics, composability, debugging).
Continue to Part 1: Conceptual Foundations \(\to\)