Part 3: Physics Applications
The Learnable Universe | Module 1 | COMP 536
Prerequisites: Part 2: Core Transformations completed
“The scientist does not study nature because it is useful; he studies it because he delights in it, and he delights in it because it is beautiful.” — Henri Poincaré
“Make it work, make it right, make it fast.” — Kent Beck
Learning Outcomes
By the end of Part 3, you will be able to:
Roadmap: From Tools to Applications
Priority: Essential
Part 2 taught you JAX transformations. Part 3 shows how to apply them to physics.
You can now:
- Compute numerically exact gradients with `grad` (machine precision)
- Compile functions with `jit`
- Vectorize operations with `vmap`
- Compose transformations correctly
Part 3 teaches application patterns:
- N-body systems: Forces, integration, energy conservation
- ODE solving: Stellar structure, chemical networks, differential equations
- Optimization: Parameter fitting, likelihood maximization
- Best practices: Testing, validation, performance benchmarking
This is NOT a cookbook. We show patterns and design choices, not complete implementations. The final project is where you write the code.
Goal: By the end, you’ll know:
- Which parts of your Project 2 code to refactor first
- Where to apply each transformation
- How to validate correctness
- What speedups to expect
3.1: N-body Systems with JAX
Priority: Essential
Architectural Overview
Your Project 2 N-body code has this structure:
```python
# NumPy version (from Project 2)
def simulate_nbody(positions, velocities, masses, dt, n_steps):
    """Integrate N-body system for n_steps with leapfrog."""
    trajectory = []
    for step in range(n_steps):
        # Kick-drift-kick leapfrog integration
        forces = compute_forces(positions, masses)
        velocities += 0.5 * dt * forces / masses[:, None]
        positions += dt * velocities
        forces = compute_forces(positions, masses)
        velocities += 0.5 * dt * forces / masses[:, None]
        # Store trajectory
        trajectory.append(positions.copy())
    return np.array(trajectory)

def compute_forces(positions, masses):
    """Compute gravitational forces via nested loops."""
    N = len(positions)
    forces = np.zeros_like(positions)
    for i in range(N):
        for j in range(N):
            if i != j:
                r_vec = positions[j] - positions[i]
                r = np.sqrt(np.sum(r_vec**2) + 1e-10)  # softened distance
                forces[i] += G * masses[i] * masses[j] * r_vec / r**3
    return forces
```

JAX refactoring strategy (what you’ll do in the final project):
- Replace force loops with vmap \(\to\) 100\(\times\) faster force calculation
- Add jit to integration step \(\to\) 10\(\times\) faster timestep
- Use lax.scan for trajectory \(\to\) Enables JIT of entire simulation
- Optional: Add grad for exact forces \(\to\) Compare to finite differences
We’ll discuss each transformation, not implement it for you.
Step 1: Vectorizing Forces with vmap
Current bottleneck: Nested loops in compute_forces.
Key insight: Force calculation is naturally parallel — each pairwise interaction is independent.
Design choice: Compute all pairwise forces at once, then sum per particle.
Pseudocode structure (you implement in the final project):
```python
def pairwise_force(r_i, r_j, m_i, m_j):
    """Force on particle i from particle j.

    This is the atomic operation you'll write.
    Pure function: no side effects, deterministic.
    """
    r_vec = r_j - r_i
    r = jnp.sqrt(jnp.sum(r_vec**2) + 1e-10)  # Softening
    force = G * m_i * m_j * r_vec / r**3
    return force

# Then apply vmap twice (you figure out in_axes):
# - Once to vectorize over j (all particles acting on i)
# - Once to vectorize over i (all particles feeling forces)
compute_forces_vectorized = jax.vmap(...)  # You design this
```

Questions to answer in the final project:
- What are the correct `in_axes` for the double `vmap`?
- How do you exclude self-interaction (i == j)?
- Should you compute the full \(N \times N\) force matrix or just the upper triangle?
- How do you sum forces correctly?
Expected result: Forces computed in one vectorized operation, no Python loops.
Benchmark target: meaningfully faster than nested loops for moderate \(N\), with the exact gain depending on hardware and implementation details.
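One way the double `vmap` can come together (a sketch, not the only valid design; here self-interaction is suppressed by the softening term rather than an explicit mask, and `G = 1.0` in code units is an assumption of this example):

```python
import jax
import jax.numpy as jnp

G = 1.0  # gravitational constant in code units (assumption for this sketch)

def pairwise_force(r_i, r_j, m_i, m_j):
    """Force on particle i from particle j (softened)."""
    r_vec = r_j - r_i
    r = jnp.sqrt(jnp.sum(r_vec**2) + 1e-10)  # softening also makes i == j contribute zero
    return G * m_i * m_j * r_vec / r**3

# Inner vmap: fix i, sweep j.  Outer vmap: sweep i.
forces_matrix = jax.vmap(
    jax.vmap(pairwise_force, in_axes=(None, 0, None, 0)),  # over j
    in_axes=(0, None, 0, None),                            # over i
)

def compute_forces_vectorized(positions, masses):
    F = forces_matrix(positions, positions, masses, masses)  # (N, N, 3)
    return jnp.sum(F, axis=1)  # sum contributions from all j onto each i
```

Note the design choice: the self-interaction term has `r_vec = 0`, so the softened force is exactly zero and no `i != j` mask is needed. Whether that trick or an explicit mask is cleaner is one of the questions above.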
Step 2: Compiling Integration with jit
After vectorizing forces, compile the timestep:
```python
@jax.jit
def leapfrog_step(state, dt):
    """One leapfrog timestep.

    state = (positions, velocities, masses)
    Returns: new_state
    Pure function: no I/O, no prints, deterministic.
    """
    positions, velocities, masses = state
    # Kick-drift-kick (you implement compute_forces_vectorized)
    forces = compute_forces_vectorized(positions, masses)
    velocities = velocities + 0.5 * dt * forces / masses[:, None]
    positions = positions + dt * velocities
    forces = compute_forces_vectorized(positions, masses)
    velocities = velocities + 0.5 * dt * forces / masses[:, None]
    return (positions, velocities, masses)
```

Design principles:
- Pure function: no side effects (no `trajectory.append()` inside)
- Fixed shapes: `positions.shape = (N, 3)` never changes
- Deterministic: same inputs \(\to\) same outputs
Why JIT helps: Fuses force calculation + integration into single compiled kernel.
Benchmark target: 10-20\(\times\) faster than NumPy per timestep.
Step 3: Trajectory Integration with lax.scan
Problem: You need to store trajectory over time. Can’t use trajectory.append() inside JIT.
Solution: Use lax.scan to iterate while JIT-compiled.
Pattern (you implement):
```python
from jax import lax

def scan_body(carry, t):
    """Body function for lax.scan.

    carry: state that persists (positions, velocities, masses)
    t: iteration counter (can use for diagnostics)
    Returns: (new_carry, output_to_stack)
    """
    state = carry
    new_state = leapfrog_step(state, dt)  # dt captured from the enclosing scope
    # What to output? Options:
    # - Full state (memory intensive)
    # - Just positions (for visualization)
    # - Energy + angular momentum (for validation)
    # - Downsampled (every 10th step)
    output = new_state[0]  # Just positions (you decide)
    return new_state, output

# Integrate
initial_state = (positions_0, velocities_0, masses)
final_state, trajectory = lax.scan(
    scan_body,
    initial_state,
    jnp.arange(n_steps)
)
# trajectory.shape = (n_steps, N, 3)
```

Key decisions you’ll make:
- Store full trajectory or downsample?
- Track energy/momentum for validation?
- How to handle diagnostics without breaking JIT?
Alternative: If trajectory doesn’t fit in memory, use checkpointing or only store final state.
Step 4: Energy Conservation Validation
Hamiltonian dynamics conserves energy. Your JAX implementation should too.
Energy conservation tests are extremely sensitive to floating-point precision. Before running any N-body simulation, you MUST enable 64-bit precision:
```python
import jax
jax.config.update("jax_enable_x64", True)  # Essential!
```

Why: roundoff errors accumulate over thousands of timesteps. With default float32, even a perfect symplectic integrator will show visible energy drift.
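A quick plain-NumPy illustration of the float32 floor, mimicking a timestep accumulator (toy example, our own construction):

```python
import numpy as np

# Accumulate 10,000 timesteps of dt = 0.01 in both precisions
dt32 = np.float32(0.01)
t32 = np.float32(0.0)
t64 = np.float64(0.0)
for _ in range(10_000):
    t32 += dt32              # each float32 addition rounds to ~7 significant digits
    t64 += np.float64(dt32)  # same increment, accumulated in float64

err32 = abs(float(t32) - float(t64))
print(f"float32 accumulated error: {err32:.2e}")  # visibly drifts; the float64 path does not
```

The same mechanism is at work in a long N-body integration: thousands of small additions to positions, velocities, and the running energy.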
Energy Test Pattern
```python
def total_energy(positions, velocities, masses, G=6.67e-8, eps2=1e20):
    """Kinetic + potential energy.

    Args:
        positions: (N, 3) array
        velocities: (N, 3) array
        masses: (N,) array
        G: Gravitational constant (CGS: 6.67e-8)
        eps2: Softening length squared (e.g., 1e20 cm^2 for typical clusters)
    """
    # Kinetic energy: T = (1/2) * sum(m * v^2)
    kinetic = 0.5 * jnp.sum(masses[:, None] * velocities**2)
    # Potential energy: U = -sum_{i<j} G*m_i*m_j / r_ij (with softening)
    potential = gravitational_potential(positions, masses, G=G, eps2=eps2)
    return kinetic + potential

# Compute energy at each timestep
energies = jax.vmap(total_energy, in_axes=(0, 0, None, None, None))(
    trajectory_pos,
    trajectory_vel,
    masses,
    G,
    eps2
)

# Check conservation
initial_energy = energies[0]
energy_drift = jnp.abs(energies - initial_energy) / jnp.abs(initial_energy)
print(f"Initial energy: {initial_energy:.6e} erg")
print(f"Final energy:   {energies[-1]:.6e} erg")
print(f"Max relative drift:  {jnp.max(energy_drift):.2e}")
print(f"Mean relative drift: {jnp.mean(energy_drift):.2e}")
```

Expected Energy Drift: Precision-Dependent Targets
Energy conservation accuracy depends critically on dtype:
| Configuration | Expected Drift | Notes |
|---|---|---|
| float64 + symplectic + small dt | \(< 10^{-12}\) | Publication quality |
| float64 + symplectic + moderate dt | \(10^{-10}\) to \(10^{-8}\) | Good for most physics |
| float32 + symplectic | \(10^{-6}\) to \(10^{-5}\) | Not suitable for long integrations |
| float32 + Euler | \(> 10^{-3}\) | Unstable, energy grows |
If your energy drift is worse than expected:
- Check dtype first: `print(positions.dtype)` \(\to\) should be `float64`
- Check timestep: is \(\Delta t < 0.01 \times T_{\text{dynamical}}\)?
- Check softening: is \(\epsilon\) consistent between potential and forces?
- Check integrator: are you using leapfrog or another symplectic method?
Common mistake: forgetting to enable float64 \(\to\) an immediate \(10^{-6}\) floor on accuracy!
Softening Effects on Energy Definition
Softening changes the Hamiltonian. Your energy definition must match your force law.
Example: Plummer softening
```python
def gravitational_potential_softened(positions, masses, G=6.67e-8, eps2=1e20):
    """Potential with Plummer softening: U = -G*m1*m2 / sqrt(r^2 + eps^2)."""
    N = len(masses)
    U = 0.0
    for i in range(N):
        for j in range(i + 1, N):
            r_vec = positions[j] - positions[i]
            r2 = jnp.sum(r_vec**2)
            r_soft = jnp.sqrt(r2 + eps2)  # Softened distance
            U -= G * masses[i] * masses[j] / r_soft
    return U
```

Critical: if forces use softening \(\epsilon^2\), then the potential MUST use the same \(\epsilon^2\). A mismatch \(\to\) energy not conserved by construction!
Softening bias: Even with perfect float64 and symplectic integration, softening introduces a systematic shift in total energy (typically < 1% for well-chosen \(\epsilon\)). This is not drift — it’s a change in the Hamiltonian itself.
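The double loop in `gravitational_potential_softened` above is fine for clarity, but it won't vectorize or JIT well. One broadcast-based sketch (the `_vec` name is ours; defaults match the loop version):

```python
import jax.numpy as jnp

def gravitational_potential_softened_vec(positions, masses, G=6.67e-8, eps2=1e20):
    """Plummer-softened potential via broadcasting; each pair counted once."""
    diff = positions[:, None, :] - positions[None, :, :]      # (N, N, 3)
    r_soft = jnp.sqrt(jnp.sum(diff**2, axis=-1) + eps2)       # (N, N), softened distances
    pair_U = -G * masses[:, None] * masses[None, :] / r_soft  # (N, N) pair potentials
    return jnp.sum(jnp.triu(pair_U, k=1))                     # strict upper triangle: i < j
```

The `triu(..., k=1)` mask plays the role of the `j > i` loop bound, so each pair (and no self-pair) is counted exactly once.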
Debugging Energy Conservation Failures
If energy drift exceeds targets:
- Verify precision:

```python
print(f"positions dtype: {positions.dtype}")    # Must be float64
print(f"velocities dtype: {velocities.dtype}")  # Must be float64
assert positions.dtype == jnp.float64, "Enable jax_enable_x64!"
```

- Check softening consistency:

```python
# Force calculation
def force_i_from_j(r_i, r_j, m_i, m_j, G, eps2):
    r_vec = r_j - r_i
    r2 = jnp.sum(r_vec**2) + eps2  # <-- Same eps2 as the potential!
    r = jnp.sqrt(r2)
    return G * m_i * m_j * r_vec / (r * r2)
```

- Visualize drift over time:
```python
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(energies / energies[0])
plt.axhline(1.0, color='k', linestyle='--', alpha=0.3)
plt.ylabel('E(t) / E(0)')
plt.xlabel('Timestep')
plt.title('Energy Conservation')
plt.subplot(1, 2, 2)
plt.semilogy(energy_drift)
plt.ylabel('|ΔE| / E₀')
plt.xlabel('Timestep')
plt.title('Relative Energy Drift (log scale)')
plt.tight_layout()
plt.show()
```

- Compare integrators:
```python
# Good: Leapfrog (symplectic, 2nd order)
# Acceptable: Velocity Verlet (symplectic, 2nd order)
# Bad: Euler (not symplectic, 1st order; energy grows!)
```

Why this matters: if energy drifts significantly beyond precision limits:
- Integration scheme may be wrong (not symplectic)
- Forces have bugs (check vectorization, softening)
- Timestep too large (reduce \(\Delta t\))
- Most common: Forgot to enable float64!
JAX’s numerically exact gradients + symplectic integration + float64 precision should conserve energy to \(\sim 10^{-12}\) relative accuracy over thousands of timesteps.
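Putting the pieces together, here is a minimal end-to-end check on a circular binary in code units (G = 1, equal masses; everything below is our own sketch in toy notation, not the project API):

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)
G = 1.0  # code units (assumption); no softening needed for this well-separated pair

def forces(pos, masses):
    diff = pos[None, :, :] - pos[:, None, :]               # diff[i, j] = r_j - r_i
    r2 = jnp.sum(diff**2, axis=-1) + jnp.eye(len(masses))  # dummy 1 on the diagonal
    mag = G * masses[:, None] * masses[None, :] / r2**1.5
    mag = mag * (1.0 - jnp.eye(len(masses)))               # zero out self-interaction
    return jnp.sum(mag[:, :, None] * diff, axis=1)

def energy(pos, vel, masses):
    T = 0.5 * jnp.sum(masses[:, None] * vel**2)
    diff = pos[:, None, :] - pos[None, :, :]
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + jnp.eye(len(masses)))
    U = jnp.sum(jnp.triu(-G * masses[:, None] * masses[None, :] / r, k=1))
    return T + U

def leapfrog(pos, vel, masses, dt, n_steps):
    def step(carry, _):
        pos, vel = carry
        vel = vel + 0.5 * dt * forces(pos, masses) / masses[:, None]
        pos = pos + dt * vel
        vel = vel + 0.5 * dt * forces(pos, masses) / masses[:, None]
        return (pos, vel), energy(pos, vel, masses)
    (pos, vel), energies = lax.scan(step, (pos, vel), None, length=n_steps)
    return pos, vel, energies

# Circular binary: separation 1, each unit mass on a radius-0.5 orbit
pos0 = jnp.array([[-0.5, 0.0, 0.0], [0.5, 0.0, 0.0]])
v = jnp.sqrt(0.5)  # circular speed from m*v^2/r = G*m*m/d^2
vel0 = jnp.array([[0.0, -v, 0.0], [0.0, v, 0.0]])
masses = jnp.ones(2)

_, _, energies = leapfrog(pos0, vel0, masses, 0.01, 1000)
drift = jnp.max(jnp.abs((energies - energies[0]) / energies[0]))
print(f"max relative energy drift: {drift:.2e}")
```

With float64 and leapfrog, the drift stays small and bounded over the run; switching to Euler or float32 in this snippet is an easy way to see the failure modes from the table above.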
Optional: Computing Forces via Autodiff
Alternative to analytical forces: Compute gradient of potential.
```python
def gravitational_potential(positions, masses):
    """Total potential energy of system."""
    # Compute all pairwise potentials, sum
    # (You implement with vmap)
    return U_total

# Forces are the negative gradient of the potential
compute_forces_autodiff = jax.grad(gravitational_potential, argnums=0)

# Compare to your analytical forces
forces_analytical = compute_forces_vectorized(positions, masses)
forces_autodiff = -compute_forces_autodiff(positions, masses)

# Should agree to machine precision
difference = jnp.max(jnp.abs(forces_analytical - forces_autodiff))
print(f"Force difference: {difference:.2e}")  # Should be ~1e-15
```

When to use autodiff forces:
- Complex potential (not just gravity)
- Want to be sure forces are correct
- Potential is easier to write than forces
When to use analytical forces:
- Simple potential (gravity, Coulomb)
- Want maximum performance (skip potential calculation)
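As a sanity check of the idea, here is a self-contained comparison on a softened pairwise potential (our own toy implementation in code units, not the project's):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
G, EPS2 = 1.0, 1e-6  # code units and a tiny softening (assumptions for this check)

def potential(positions, masses):
    """Softened pairwise potential; each pair counted once via triu."""
    diff = positions[:, None, :] - positions[None, :, :]
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + EPS2)
    U = -G * masses[:, None] * masses[None, :] / r
    return jnp.sum(jnp.triu(U, k=1))

def forces_analytical(positions, masses):
    """Hand-derived softened gravity, consistent with `potential`."""
    diff = positions[None, :, :] - positions[:, None, :]   # diff[i, j] = r_j - r_i
    r2 = jnp.sum(diff**2, axis=-1) + EPS2
    mag = G * masses[:, None] * masses[None, :] / r2**1.5
    return jnp.sum(mag[:, :, None] * diff, axis=1)         # diagonal diff is zero

# Forces from the potential via autodiff: F = -grad U
forces_autodiff = lambda pos, m: -jax.grad(potential)(pos, m)
```

Because both definitions use the same `EPS2`, the two force calculations agree to machine precision; a mismatched softening is exactly the kind of bug this comparison catches.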
Module 3 taught: Hamiltonian dynamics preserves phase space volume (Liouville’s theorem) and energy (if \(H\) is time-independent).
JAX enables exact Hamiltonian evolution:
- Symplectic integrators (leapfrog) preserve structure numerically
- Autodiff gives exact \(\nabla H\) (no finite-difference errors)
- Energy conservation validates both integrator and forces
Your final-project N-body system is a computational realization of Hamiltonian mechanics from Module 3. Small, bounded energy drift over long integrations is evidence that you’ve implemented it carefully and chosen a sensible timestep.
3.2: ODE Solving with JAX
Priority: Important
When You Need ODE Solvers
Astrophysics is full of coupled ODEs:
- Stellar structure: \(\frac{dM}{dr}, \frac{dP}{dr}, \frac{dT}{dr}, \frac{dL}{dr}\) (Module 2)
- Chemical networks: \(\frac{dn_i}{dt} = \sum_j R_{ij}(n_j, T)\) (reaction rates)
- Orbit determination: \(\frac{d\mathbf{r}}{dt} = \mathbf{v}, \frac{d\mathbf{v}}{dt} = \mathbf{a}(\mathbf{r})\)
- Radiative transfer: \(\frac{dI}{ds} = -\kappa I + j\) (intensity along ray)
JAX ODE solving: Use Diffrax (JAX-native ODE library) or implement custom integrators.
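For instance, if you would rather not depend on an external library, a fixed-step classical RK4 built on `lax.scan` is only a few lines (a sketch in our own notation; Diffrax's adaptive solvers remain the better choice for stiff problems):

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)

def rk4_solve(f, y0, t0, t1, n_steps, args):
    """Fixed-step classical RK4; f(t, y, args) -> dy/dt."""
    dt = (t1 - t0) / n_steps

    def step(y, t):
        k1 = f(t, y, args)
        k2 = f(t + dt / 2, y + dt / 2 * k1, args)
        k3 = f(t + dt / 2, y + dt / 2 * k2, args)
        k4 = f(t + dt, y + dt * k3, args)
        y_new = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return y_new, y_new  # carry forward, and stack for the solution history

    ts = t0 + dt * jnp.arange(n_steps)
    y_final, ys = lax.scan(step, y0, ts)
    return y_final, ys

# Example right-hand side: exponential decay dy/dt = -k*y (analytic: y0*exp(-k*t))
decay = lambda t, y, k: -k * y
```

Because the whole solve is one `lax.scan`, it JIT-compiles and is differentiable end to end, which is the same property Diffrax provides with much more machinery.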
Pattern: Stellar Structure Equations
Module 2 equations (Eddington’s 4 coupled ODEs):
\[ \frac{dM}{dr} = 4\pi r^2 \rho, \quad \frac{dP}{dr} = -\frac{GM\rho}{r^2}, \quad \frac{dT}{dr} = -\frac{3\kappa\rho L}{16\pi a c r^2 T^3}, \quad \frac{dL}{dr} = 4\pi r^2 \rho \epsilon \]
JAX structure (using Diffrax):
```python
import diffrax

def stellar_structure_rhs(r, y, args):
    """Right-hand side of stellar structure ODEs.

    y = [M, P, T, L] at radius r
    args = (rho_c, opacity_func, epsilon_func)
    Returns: dy/dr = [dM/dr, dP/dr, dT/dr, dL/dr]
    Pure function: deterministic, no side effects.
    """
    M, P, T, L = y
    rho_c, opacity_func, epsilon_func = args

    # Compute density from P, T (EOS)
    rho = equation_of_state(P, T)

    # Compute opacity and energy generation
    kappa = opacity_func(rho, T)
    epsilon = epsilon_func(rho, T)

    # Derivatives
    dM_dr = 4 * jnp.pi * r**2 * rho
    dP_dr = -G * M * rho / r**2
    dT_dr = -3 * kappa * rho * L / (16 * jnp.pi * a * c * r**2 * T**3)
    dL_dr = 4 * jnp.pi * r**2 * rho * epsilon
    return jnp.array([dM_dr, dP_dr, dT_dr, dL_dr])

# Solve from center to surface
solution = diffrax.diffeqsolve(
    diffrax.ODETerm(stellar_structure_rhs),
    diffrax.Dopri5(),  # Adaptive RK45
    t0=r_center,
    t1=r_surface,
    dt0=dr_initial,
    y0=jnp.array([M_center, P_center, T_center, L_center]),
    args=(rho_c, opacity_func, epsilon_func),
    saveat=diffrax.SaveAt(ts=r_grid)
)
# solution.ys = [M(r), P(r), T(r), L(r)] on the grid
```

Why Diffrax:
- Adaptive timestepping (handles stiff equations)
- JIT-compilable (fast repeated solves)
- Autodiff through solutions (sensitivity analysis)
- Many solvers (RK4, Dopri5, implicit methods)
Gradients Through ODE Solutions
Powerful application: Optimize initial conditions or parameters by taking gradients through the solver.
Example: Find central density \(\rho_c\) that produces correct stellar radius.
```python
def stellar_radius(rho_c, opacity_func, epsilon_func):
    """Solve structure equations, return surface radius.

    This is differentiable!
    """
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(stellar_structure_rhs),
        diffrax.Dopri5(),
        t0=0.0,
        t1=R_max,  # Integrate until surface
        dt0=1e-3,
        y0=initial_conditions(rho_c),
        args=(rho_c, opacity_func, epsilon_func),
    )
    # Extract radius where P=0 (surface)
    return find_surface_radius(solution)

# Gradient: how does radius change with central density?
dR_drho_c = jax.grad(stellar_radius)(rho_c, opacity_func, epsilon_func)

# Use in optimization: match observed radius
def loss(rho_c):
    R_model = stellar_radius(rho_c, opacity_func, epsilon_func)
    R_obs = 6.96e10  # Solar radius in cm
    return (R_model - R_obs)**2

# Gradient descent to find the correct rho_c
rho_c_optimal = optimize_with_grad(loss, rho_c_init)
```

Why this is powerful: autodiff gives exact sensitivities without finite differences, enabling efficient root-finding and parameter fitting.
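The same differentiate-through-the-solver idea, shown on a problem small enough to run without Diffrax (hand-rolled Euler on exponential decay; the problem and names are our own toy assumptions):

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)

def y_at_T(k, y0=1.0, T=1.0, n_steps=1000):
    """Integrate dy/dt = -k*y with explicit Euler; differentiable in k."""
    dt = T / n_steps
    def step(y, _):
        return y * (1.0 - k * dt), None
    y_final, _ = lax.scan(step, y0, None, length=n_steps)
    return y_final

# Sensitivity of the final state to the decay rate, through the whole solve
dy_dk = jax.grad(y_at_T)(2.0)
# Analytic answer: d/dk [y0 * exp(-k*T)] = -T * y0 * exp(-k*T)
```

This is exactly what `jax.grad(stellar_radius)` does above, with the Euler loop replaced by Diffrax's adaptive solver: the gradient flows through every timestep of the integration.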
Module 2: You derived Eddington’s equations from hydrostatic equilibrium, energy transport, and nuclear physics.
JAX + Diffrax: You can now solve these equations numerically, differentiate through them, and optimize stellar models.
Example application:
- Given observed \(M_\star\) and \(R_\star\), find \(\rho_c\) and composition that match
- Compute \(\frac{\partial R_\star}{\partial Z}\) (radius sensitivity to metallicity)
- Infer stellar age from surface temperature evolution
This is computational stellar astrophysics in practice.
3.3: Optimization Patterns
Priority: Essential
Gradient-Based Optimization
Common astrophysics problem: Find parameters \(\theta\) that minimize \(\chi^2\) or maximize likelihood.
Examples:
- Fit period-luminosity relation to Cepheids (Project 1)
- Infer cosmological parameters from supernovae (Project 4)
- Optimize neural network weights (the machine-learning module and final project)
JAX advantage: Numerically exact gradients via autodiff (machine precision), fast optimization.
Pattern 1: Least-Squares Fitting
Problem: Fit model \(y = f(x; \theta)\) to data \((x_i, y_i, \sigma_i)\).
Loss function: \(\chi^2 = \sum_i \frac{(y_i - f(x_i; \theta))^2}{\sigma_i^2}\)
JAX implementation:
```python
def model(x, params):
    """Model prediction for a single data point."""
    # Your model (linear, power-law, etc.)
    return prediction

def chi_squared(params, x_data, y_data, sigma_data):
    """Chi-squared loss over all data."""
    # Vectorize the model over data
    predictions = jax.vmap(model, in_axes=(0, None))(x_data, params)
    residuals = y_data - predictions
    return jnp.sum((residuals / sigma_data)**2)

# Gradient for optimization
grad_chi_squared = jax.jit(jax.grad(chi_squared, argnums=0))

# Gradient descent (simple version)
params = params_init
for step in range(1000):
    gradient = grad_chi_squared(params, x_data, y_data, sigma_data)
    params = params - learning_rate * gradient

# Or use Optax (JAX optimization library)
import optax

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
for step in range(1000):
    loss, gradient = jax.value_and_grad(chi_squared)(params, x_data, y_data, sigma_data)
    updates, opt_state = optimizer.update(gradient, opt_state)
    params = optax.apply_updates(params, updates)
```

Key points:
- Use `jax.value_and_grad` to get the loss and gradient in one call (efficient)
- Optax provides advanced optimizers (Adam, SGD with momentum, L-BFGS)
- JIT the gradient function for speed
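To make the pattern concrete, here is a complete runnable version for a straight-line fit (synthetic data with a known answer `a = 2, b = 1`; plain gradient descent so nothing beyond JAX is required):

```python
import jax
import jax.numpy as jnp

def model(x, params):
    """Linear model y = a*x + b for one data point (illustrative choice)."""
    a, b = params
    return a * x + b

def chi_squared(params, x, y, sigma):
    pred = jax.vmap(model, in_axes=(0, None))(x, params)
    return jnp.sum(((y - pred) / sigma) ** 2)

# Synthetic, noise-free data with known truth (assumption for the demo)
x = jnp.linspace(0.0, 1.0, 50)
y = 2.0 * x + 1.0
sigma = jnp.ones_like(x)

# value_and_grad gives loss and gradient in one jitted call
loss_and_grad = jax.jit(jax.value_and_grad(chi_squared))
params = jnp.array([0.0, 0.0])
for _ in range(500):
    loss, g = loss_and_grad(params, x, y, sigma)
    params = params - 1e-2 * g  # fixed learning rate, chosen for this problem
```

With real noisy data you would swap the loop for an Optax optimizer as above, but the structure (vmapped model, jitted `value_and_grad`, parameter update) is identical.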
Pattern 2: Maximum Likelihood Estimation
Project 4 application: Find \((\Omega_m, \Omega_\Lambda, H_0)\) that maximize \(P(D|\theta)\).
Negative log-likelihood (minimize):
```python
def neg_log_likelihood(theta, data):
    """Negative log-likelihood for cosmological parameters.

    theta = [Omega_m, Omega_Lambda, H0]
    data = {'z': redshifts, 'mu_obs': observed_moduli, 'sigma_mu': uncertainties}
    """
    Omega_m, Omega_Lambda, H0 = theta

    # Prior (optional): penalize the unphysical region.
    # (A Python `if` on traced values would break jax.grad, so use jnp.where.)
    out_of_bounds = (Omega_m < 0) | (Omega_Lambda < 0) | (H0 < 50) | (H0 > 100)
    penalty = jnp.where(out_of_bounds, 1e10, 0.0)

    # Predicted distance moduli
    mu_pred = distance_modulus_batch(data['z'], H0, Omega_m, Omega_Lambda)

    # Gaussian likelihood
    residuals = data['mu_obs'] - mu_pred
    chi2 = jnp.sum((residuals / data['sigma_mu'])**2)
    return 0.5 * chi2 + penalty  # Negative log-likelihood (up to a constant)

# Find maximum likelihood parameters
from scipy.optimize import minimize

result = minimize(
    neg_log_likelihood,
    x0=[0.3, 0.7, 70.0],  # Initial guess
    args=(data,),
    jac=jax.grad(neg_log_likelihood),  # Numerically exact gradient (machine precision)!
    method='L-BFGS-B',
    bounds=[(0, 1), (0, 2), (50, 100)]
)
theta_MLE = result.x
print(f"Best-fit: Omega_m={theta_MLE[0]:.3f}, Omega_Lambda={theta_MLE[1]:.3f}, H0={theta_MLE[2]:.1f}")
```

Why JAX helps: `jax.grad` provides a numerically exact gradient (machine precision), so L-BFGS converges in 10-50 iterations (vs thousands of evaluations for gradient-free methods).
Pattern 3: Uncertainty Quantification
After finding best-fit parameters, estimate uncertainties.
Method 1: Hessian approximation (fast, approximate):
```python
# Hessian of the negative log-likelihood at the MLE
hessian = jax.hessian(neg_log_likelihood)(theta_MLE, data)

# Covariance matrix (inverse Hessian)
covariance = jnp.linalg.inv(hessian)

# Parameter uncertainties
uncertainties = jnp.sqrt(jnp.diag(covariance))
print(f"H0 = {theta_MLE[2]:.1f} ± {uncertainties[2]:.1f} km/s/Mpc")
```

Method 2: MCMC (Project 4, more accurate):
```python
# Use JAX-accelerated HMC (from Part 2)
grad_log_posterior = jax.jit(jax.grad(log_posterior, argnums=0))

def hmc_step(theta, key, data):
    # Your HMC implementation using grad_log_posterior
    # (Much faster than finite differences!)
    return theta_new

# Run chain
chain = run_mcmc(hmc_step, theta_init, data, n_steps=10000)
# Analyze posteriors: mean, std, correlations
```

JAX advantage: HMC with numerically exact gradients (machine precision) converges 10-100\(\times\) faster than Metropolis-Hastings.
In Project 4, you’ll implement gradient-based MCMC for supernova cosmology.
JAX will accelerate:
- Gradient computation: `jax.grad(log_posterior)` replaces finite differences (100\(\times\) faster)
- Likelihood evaluation: `jit` compiles the distance modulus calculation (10\(\times\) faster)
- Vectorization: `vmap` over 43 supernovae simultaneously
Combined speedup: 1000\(\times\) faster than pure Python + finite differences.
This transforms MCMC from “run overnight” to “run in 5 minutes” — enabling rapid iteration and exploration.
3.4: Best Practices and Testing
Priority: Important
Code Structure for JAX
Separate pure computation from I/O:
```python
# GOOD: Pure numerical core
@jax.jit
def compute_observables(state, params):
    """Pure function: state -> observables."""
    energy = total_energy(state, params)
    angular_momentum = compute_angular_momentum(state, params)
    return {'energy': energy, 'L': angular_momentum}

# I/O wrapper (not JIT-compiled)
def simulate_with_logging(initial_state, params, n_steps):
    """Run simulation with diagnostics."""
    state = initial_state
    trajectory = []
    for step in range(n_steps):
        state = integration_step(state, params)  # JIT-compiled
        if step % 100 == 0:
            # Diagnostics (Python, not JIT)
            obs = compute_observables(state, params)
            print(f"Step {step}: E={obs['energy']:.6e}, L={obs['L']:.6e}")
        trajectory.append(state)
    return trajectory
```

Key principle: JIT the hot inner loop; keep I/O in the Python wrapper.
Validation Strategy
1. Test against NumPy baseline:
```python
# Generate the same random data for both implementations
key = jax.random.PRNGKey(42)
key_pos, key_mass = jax.random.split(key)  # never reuse a PRNG key
positions = jax.random.normal(key_pos, (N, 3))
masses = jax.random.uniform(key_mass, (N,))

# Compute forces both ways
forces_numpy = compute_forces_numpy(np.array(positions), np.array(masses))
forces_jax = compute_forces_jax(positions, masses)

# Compare
difference = np.max(np.abs(forces_numpy - np.array(forces_jax)))
print(f"Max difference: {difference:.2e}")
assert difference < 1e-10, "Forces don't match!"
```

2. Check conservation laws:
```python
# Energy conservation (Hamiltonian systems)
energies = [total_energy(state) for state in trajectory]
energy_drift = (energies[-1] - energies[0]) / energies[0]
assert abs(energy_drift) < 1e-8, f"Energy drift: {energy_drift:.2e}"

# Angular momentum conservation (isolated system)
L_initial = compute_angular_momentum(trajectory[0])
L_final = compute_angular_momentum(trajectory[-1])
L_drift = jnp.linalg.norm(L_final - L_initial) / jnp.linalg.norm(L_initial)
assert L_drift < 1e-8, f"Angular momentum drift: {L_drift:.2e}"
```

3. Gradient checks (autodiff vs finite differences):
```python
def check_gradients(f, x, h=1e-5):
    """Compare the autodiff gradient to central finite differences."""
    # Autodiff
    grad_autodiff = jax.grad(f)(x)

    # Finite differences, perturbing one component at a time
    # (flatten so this works for any input shape, e.g. (N, 3) positions)
    x_flat = x.ravel()
    grad_fd = np.zeros(x_flat.shape)
    for i in range(x_flat.size):
        x_plus = x_flat.at[i].add(h).reshape(x.shape)
        x_minus = x_flat.at[i].add(-h).reshape(x.shape)
        grad_fd[i] = (f(x_plus) - f(x_minus)) / (2 * h)

    # Compare
    difference = np.max(np.abs(np.array(grad_autodiff).ravel() - grad_fd))
    print(f"Gradient difference: {difference:.2e}")
    assert difference < 1e-6, "Gradients don't match!"

# Test on the gravitational potential
check_gradients(lambda pos: gravitational_potential(pos, masses), positions)
```

Performance Benchmarking
Measure what you’re optimizing:
```python
import time

def benchmark_function(f, *args, n_runs=100, warmup=10):
    """Benchmark a (possibly JIT-compiled) function."""
    # Warmup (triggers compilation)
    for _ in range(warmup):
        _ = f(*args)

    # Time execution
    start = time.time()
    for _ in range(n_runs):
        result = f(*args)
    # JAX dispatches asynchronously: wait for the last result before stopping the clock
    if hasattr(result, "block_until_ready"):
        result.block_until_ready()
    end = time.time()

    return (end - start) / n_runs

# Compare NumPy vs JAX
positions = np.random.randn(100, 3)
masses = np.random.rand(100)

time_numpy = benchmark_function(compute_forces_numpy, positions, masses)
time_jax = benchmark_function(compute_forces_jax,
                              jnp.array(positions),
                              jnp.array(masses))

speedup = time_numpy / time_jax
print(f"NumPy: {time_numpy*1e3:.2f} ms")
print(f"JAX:   {time_jax*1e3:.2f} ms")
print(f"Speedup: {speedup:.1f}×")
```

Expected speedups (from our testing):
- Force calculation (N=100): 50-200\(\times\) on GPU, 10-50\(\times\) on CPU
- Full integration step: 10-30\(\times\) on GPU, 5-15\(\times\) on CPU
- Full trajectory (1000 steps): 100-500\(\times\) on GPU, 20-100\(\times\) on CPU
Your mileage will vary based on hardware, problem size, and code quality.
Common Pitfalls
1. Premature optimization:
```python
# DON'T: Optimize before you have working code.
# Write the NumPy version first, validate correctness, THEN refactor to JAX.
# DO: Make it work, make it right, make it fast.
```

2. Over-JIT-ing:
```python
# DON'T: JIT everything, including I/O
@jax.jit  # BAD: contains print and file I/O
def simulate_with_output(state):
    print(f"State: {state}")
    np.savetxt("trajectory.txt", state)
    return next_state(state)

# DO: JIT only pure numerical kernels
```

3. Ignoring compilation time:
```python
# DON'T: call a JIT-compiled function with many different shapes
for N in [10, 20, 50, 100]:  # Recompiles 4 times!
    result = jitted_function(jnp.zeros((N, 3)))

# DO: fix shapes or accept the compilation overhead
```

3.5: Preparing for the Final Project JAX Rebuild
Priority: Essential
What You’ll Refactor
Final project task: Rebuild your Project 2 N-body code in JAX-native form, validate it carefully, and turn it into the simulator that will feed your emulator.
Refactoring checklist:
- Forces:
- Integration:
- Optional enhancements:
- Package structure:
Rebuild Strategy
Phase 1: Rebuild and validate (first week)
- Convert force calculation to JAX
- Test against NumPy (should match to 1e-10)
- Benchmark on small system (N=10)
Phase 2: Compile and optimize (first half)
- Add JIT to integration step
- Use lax.scan for trajectory
- Benchmark on larger system (N=100)
Phase 3: Reproducibility and structure (middle of the project)
- Structure the repo so a grader can run it cleanly
- Write tests and documentation
- Add example scripts
- Test installation on fresh environment
Phase 4: Extensions (optional, only after the baseline works)
- Autodiff forces
- Advanced diagnostics
- Performance profiling
- GPU optimization
Expected Outcomes
After this JAX rebuild, you will:
- Have a validated, reproducible N-body simulator that is much stronger than your original Project 2 version
- Understand JAX transformation patterns deeply
- Be able to refactor any numerical code to JAX
- Have experience organizing scientific software for reuse
- Be prepared for Module 2 and the emulator phase of the final project
This is the culmination of your “glass-box” journey: From writing Newton’s equations from scratch (Project 2) \(\to\) understanding computational patterns (Module 1) \(\to\) rebuilding that simulator as professional scientific software for the final project.
Synthesis: From Transformations to Physics
Module 1 Parts 1-3 complete arc:
Part 1: Why functional programming? (Purity enables transformations)
Part 2: How do transformations work? (grad, jit, vmap mechanics)
Part 3: How do we apply them? (N-body, ODEs, optimization patterns)
Unified by one principle: Pure functions + structured control flow \(\to\) powerful automated transformations
The JAX Mindset
Traditional scientific computing:
- Write loops
- Optimize by hand (vectorize, parallelize, compile)
- Debug performance issues
- Repeat for each new problem
JAX scientific computing:
- Write pure functions (single example)
- Apply transformations (vmap, jit, grad)
- Validate correctness
- Benchmark (usually 10-1000\(\times\) faster automatically)
The shift: From “optimizing code” to “designing pure functions that can be optimized.”
Connections Across Course
Module 1 (Statistics) \(\to\) Module 1 (JAX):
- CLT enabled inference from samples \(\to\) vmap enables inference from batches
- MaxEnt found optimal distributions \(\to\) transformations find optimal code
Module 2 (Stellar Structure) \(\to\) Module 1 (JAX):
- Differential equations describe stars \(\to\) Diffrax solves them
- Boundary conditions constrain solutions \(\to\) grad optimizes parameters
Module 3 (Phase Space) \(\to\) Module 1 (JAX):
- Hamiltonian dynamics conserve energy \(\to\) Symplectic JAX integration does too
- Liouville theorem preserves volume \(\to\) Pure functions preserve determinism
Module 4 (Radiative Transfer) \(\to\) Module 1 (JAX):
- Monte Carlo sampling \(\to\) jit-compiled random walks (next step: GPU acceleration)
- Path tracing \(\to\) vmap over photons
Module 5 (MCMC) \(\to\) Module 1 (JAX):
- Gradient-based sampling (HMC) \(\to\) grad provides numerically exact gradients (machine precision)
- Likelihood evaluation \(\to\) jit + vmap accelerate computation
Module 2 (Machine Learning) \(\leftarrow\) Module 1 (JAX):
- Neural networks are just differentiable functions
- Training = gradient descent with jit(vmap(grad(loss)))
- Same patterns, different applications
Understanding Checklist
Before starting the final project’s JAX rebuild, ensure you can:
If you answered “yes” to all \(\to\) Ready for the final project’s JAX rebuild
If uncertain on 3+ items \(\to\) Review relevant Part 2/3 sections, ask in office hours
Next Steps
Immediate: Start thinking about your final-project JAX rebuild strategy
- What parts of your N-body code are candidates for vmap?
- Where will JIT help most?
- What validation tests do you need?
Early phase: Implement force calculation with vmap, validate Next phase: Add jit + lax.scan, organize the repo, document the workflow
Then: Module 2 uses these same JAX patterns for neural networks
- Forward pass = function composition
- Backward pass = grad (autodiff)
- Training = jit(vmap(grad(loss)))
- Prediction = vmap(forward_pass)
You’re now equipped to build high-performance scientific software. The transformation from “script writer” to “scientific software engineer” is complete.