Neural ODE Orbit Propagation with JAX

Let automatic differentiation learn orbital physics from trajectory data.

Undergraduate · Orbital Mechanics · 4–6 weeks
Last reviewed: March 2026

Overview

Orbit propagation is the problem of predicting a satellite's future position from its current state. Classical approaches use analytical or numerical integration of Newton's equations augmented with perturbation models (J2 oblateness, atmospheric drag, solar radiation pressure). While accurate when the perturbation model is correct, classical propagators struggle when the force environment is poorly characterised — as is the case for debris objects, novel propulsion systems, or satellites in non-standard orbits.

Neural ODEs (introduced by Chen et al., 2018) represent a powerful alternative: instead of specifying the dynamics equations analytically, you parameterise the derivative function with a neural network and learn it from trajectory observations. JAX, Google's high-performance numerical computing library, makes this tractable: its JIT compilation and automatic differentiation let gradients flow through ODE integration steps, and the adjoint method keeps the memory cost of doing so manageable. The result is a differentiable physics simulator that can be trained end-to-end on measured data.
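
To make the idea concrete, the toy sketch below parameterises dy/dt with a tiny MLP and differentiates a fixed-step RK4 rollout with jax.grad. The layer sizes, step count, and initialisation are arbitrary demo choices; the project itself delegates the integration to diffrax rather than hand-rolling RK4.

```python
import jax
import jax.numpy as jnp

def mlp_rhs(params, t, y):
    """Tiny MLP standing in for the unknown dynamics f(t, y)."""
    w1, b1, w2, b2 = params
    h = jnp.tanh(w1 @ y + b1)
    return w2 @ h + b2

def rk4_step(params, t, y, dt):
    """One classical RK4 step; JAX differentiates straight through it."""
    k1 = mlp_rhs(params, t, y)
    k2 = mlp_rhs(params, t + dt / 2, y + dt * k1 / 2)
    k3 = mlp_rhs(params, t + dt / 2, y + dt * k2 / 2)
    k4 = mlp_rhs(params, t + dt, y + dt * k3)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def loss(params, y0, y_target):
    """Integrate 10 steps, then penalise distance to a target state."""
    y = y0
    for i in range(10):
        y = rk4_step(params, 0.1 * i, y, 0.1)
    return jnp.sum((y - y_target) ** 2)

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = (0.1 * jax.random.normal(k1, (8, 2)), jnp.zeros(8),
          0.1 * jax.random.normal(k2, (2, 8)), jnp.zeros(2))
# Gradients of the integrated loss w.r.t. every network parameter.
grads = jax.grad(loss)(params, jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
print(jax.tree_util.tree_map(jnp.shape, grads))
```

The key point is that nothing special is needed to differentiate through the solver loop: the gradient flows through every RK4 stage automatically.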

This project walks you through implementing a Neural ODE propagator for Earth-orbiting satellites using simulated two-body + J2 trajectories as training data, then tests how well the model extrapolates beyond its training distribution. You will develop deep intuition for automatic differentiation, implicit numerical integration, and the interplay between data-driven and physics-based modelling — a cutting-edge combination increasingly used in spacecraft navigation and space situational awareness.

What You'll Learn

  • Implement basic orbital mechanics (two-body, J2) in JAX using functional programming patterns
  • Build a Neural ODE using Diffrax (JAX ODE solver) and Equinox (neural network library)
  • Train differentiable ODE models using the adjoint sensitivity method
  • Compare Neural ODE propagation accuracy against classical propagators over multi-orbit time horizons
  • Analyse Neural ODE generalisation: where learned dynamics fail outside the training distribution

Step-by-Step Guide

1. Implement classical orbit propagators in JAX

Write two-body and J2-perturbed propagators as pure JAX functions using diffrax.diffeqsolve with a Dormand-Prince integrator. Generate a dataset of 500 simulated LEO trajectories (varying inclination, eccentricity, and RAAN) with position and velocity sampled at 60-second intervals over one orbital period.
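
A dependency-free sketch of the dynamics for this step. The constants are standard WGS-84-style values and the sample orbit is a hypothetical 500 km circular LEO; in the project, two_body_j2_rhs would be handed to diffrax.diffeqsolve with a Dopri5 solver instead of the hand-rolled RK4 used here.

```python
import jax
import jax.numpy as jnp

MU = 3.986004418e14   # Earth's gravitational parameter [m^3/s^2]
J2 = 1.08262668e-3    # Earth's J2 zonal harmonic coefficient
RE = 6378137.0        # Earth's equatorial radius [m]

def two_body_j2_rhs(t, state):
    """d(state)/dt for state = [x, y, z, vx, vy, vz] under two-body + J2."""
    r_vec, v_vec = state[:3], state[3:]
    r = jnp.linalg.norm(r_vec)
    x, y, z = r_vec
    a_kepler = -MU * r_vec / r**3
    # Standard J2 perturbing acceleration in ECI coordinates.
    factor = 1.5 * J2 * MU * RE**2 / r**5
    zr2 = 5.0 * z**2 / r**2
    a_j2 = factor * jnp.array([x * (zr2 - 1.0),
                               y * (zr2 - 1.0),
                               z * (zr2 - 3.0)])
    return jnp.concatenate([v_vec, a_kepler + a_j2])

def rk4(rhs, state, t, dt):
    """One fixed RK4 step (diffrax replaces this in the real project)."""
    k1 = rhs(t, state)
    k2 = rhs(t + dt / 2, state + dt * k1 / 2)
    k3 = rhs(t + dt / 2, state + dt * k2 / 2)
    k4 = rhs(t + dt, state + dt * k3)
    return state + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

# Propagate a ~500 km circular LEO for ten minutes at 60 s sampling.
r0 = RE + 500e3
state = jnp.array([r0, 0.0, 0.0, 0.0, jnp.sqrt(MU / r0), 0.0])

def step(s, t):
    s_next = rk4(two_body_j2_rhs, s, t, 60.0)
    return s_next, s_next

_, traj = jax.lax.scan(step, state, jnp.arange(10) * 60.0)
print(traj.shape)  # (10, 6): ten sampled states
```

Writing the right-hand side as a pure function of (t, state) is what makes it reusable: the same signature works for the classical propagator, the neural RHS, and any hybrid of the two.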

2. Define the Neural ODE architecture

Using the Equinox library, define an MLP that maps state (position, velocity, time) to its derivative (velocity, acceleration). Wrap this network in a diffrax ODE term so the neural network acts as the right-hand-side function of the orbital equations. Verify that forward integration produces a valid trajectory shape before training.
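
In the project, Equinox's eqx.nn.MLP would supply the network; to keep this sketch self-contained it uses a hand-rolled plain-JAX MLP with hypothetical layer sizes, and performs the shape check the step calls for.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Initialise MLP weights (eqx.nn.MLP plays this role in the project)."""
    params = []
    for k, (n_in, n_out) in zip(jax.random.split(key, len(sizes) - 1),
                                zip(sizes[:-1], sizes[1:])):
        params.append((jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in),
                       jnp.zeros(n_out)))
    return params

def neural_rhs(params, t, state):
    """Neural right-hand side: (t, position, velocity) -> d(state)/dt."""
    h = jnp.concatenate([jnp.array([t]), state])  # 7 inputs: t + 6 state
    for w, b in params[:-1]:
        h = jnp.tanh(w @ h + b)
    w, b = params[-1]
    return w @ h + b                              # 6 outputs: vel + accel

# Hypothetical architecture: 7 -> 64 -> 64 -> 6.
params = init_mlp(jax.random.PRNGKey(0), [7, 64, 64, 6])
dy = neural_rhs(params, 0.0, jnp.ones(6))
print(dy.shape)  # (6,): a valid state derivative
```

Checking that the untrained network already returns a well-shaped, finite derivative catches wiring bugs before any expensive training run.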

3. Implement the adjoint training loop

Use jax.grad through the diffrax.diffeqsolve call (adjoint method enabled via diffrax.RecursiveCheckpointAdjoint) to compute gradients of a position MSE loss with respect to network parameters. Implement a training loop with the Adam optimiser from Optax. Monitor training loss and integrated trajectory error per epoch.
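
A minimal stand-in for this loop: it backpropagates through an RK4 rollout with jax.value_and_grad and applies plain gradient descent, where the project would use diffrax's RecursiveCheckpointAdjoint and Optax's Adam. The 2-D toy dynamics, network sizes, and learning rate are arbitrary demo choices.

```python
import jax
import jax.numpy as jnp

def rhs(params, t, y):
    """Toy neural RHS; the project's version comes from Equinox."""
    w1, b1, w2, b2 = params
    return w2 @ jnp.tanh(w1 @ y + b1) + b2

def integrate(params, y0, ts):
    """RK4 rollout; jax.grad backpropagates through every step
    (RecursiveCheckpointAdjoint does the same memory-efficiently)."""
    def step(y, t_pair):
        t, dt = t_pair
        k1 = rhs(params, t, y)
        k2 = rhs(params, t + dt / 2, y + dt * k1 / 2)
        k3 = rhs(params, t + dt / 2, y + dt * k2 / 2)
        k4 = rhs(params, t + dt, y + dt * k3)
        y_next = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return y_next, y_next
    _, ys = jax.lax.scan(step, y0, (ts[:-1], jnp.diff(ts)))
    return jnp.vstack([y0[None], ys])

def loss(params, y0, ts, y_true):
    return jnp.mean((integrate(params, y0, ts) - y_true) ** 2)  # trajectory MSE

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = (0.1 * jax.random.normal(k1, (16, 2)), jnp.zeros(16),
          0.1 * jax.random.normal(k2, (2, 16)), jnp.zeros(2))
ts = jnp.linspace(0.0, 1.0, 11)
y_true = jnp.stack([jnp.cos(ts), jnp.sin(ts)], axis=1)  # target: a circular arc

@jax.jit
def train_step(params, lr=1e-2):
    l, g = jax.value_and_grad(loss)(params, y_true[0], ts, y_true)
    return jax.tree_util.tree_map(lambda p, gi: p - lr * gi, params, g), l

l0 = None
for _ in range(200):
    params, l = train_step(params)
    l0 = l if l0 is None else l0
print(float(l0), float(l))  # loss should decrease over training
```

Swapping the gradient-descent update for optax.adam and the RK4 rollout for diffeqsolve changes the efficiency, not the structure, of this loop.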

4. Train and evaluate propagation accuracy

Train the Neural ODE on 400 trajectories and evaluate on a held-out set of 100. Plot position error vs. time horizon for both the Neural ODE and the classical J2 propagator. Investigate whether the Neural ODE recovers J2 effects by decomposing its prediction errors into radial, along-track, and cross-track components.
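
One way to produce the error-vs-horizon curve, assuming batched (n_traj, n_times, 6) state arrays; the synthetic pred and truth arrays below are placeholders for the real model outputs.

```python
import jax.numpy as jnp

def position_error_vs_horizon(pred, truth):
    """RMS position error at each sample time across a batch of trajectories.

    pred, truth: (n_traj, n_times, 6) arrays whose first three columns are
    position. Returns an (n_times,) curve to plot against the time horizon.
    """
    err = jnp.linalg.norm(pred[..., :3] - truth[..., :3], axis=-1)
    return jnp.sqrt(jnp.mean(err ** 2, axis=0))

# Placeholder data: 100 held-out trajectories, 90 sample times, 6 state entries,
# with an artificial error that grows linearly in x to 1000 m.
truth = jnp.zeros((100, 90, 6))
pred = truth.at[..., 0].add(jnp.linspace(0.0, 1000.0, 90))
curve = position_error_vs_horizon(pred, truth)
print(curve.shape, float(curve[-1]))  # (90,) and 1000 m at the final horizon
```

The same reduction, applied after rotating the error vectors into the radial/along-track/cross-track frame, gives the per-component curves used to diagnose J2 recovery.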

5. Stress-test generalisation with drag perturbations

Generate a second test dataset that includes atmospheric drag (a perturbation not seen during training). Evaluate propagation error for the Neural ODE vs. a J2-only classical propagator on these trajectories. Analyse whether the neural model partially captures drag effects or fails gracefully, and quantify the error growth rate.
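
A sketch of the drag perturbation that could be added when generating the second dataset: cannonball drag with an exponential atmosphere. The density reference, scale height, and Cd·A/m below are rough illustrative values for a hypothetical spacecraft, not a calibrated model.

```python
import jax.numpy as jnp

RE = 6378137.0                    # Earth's equatorial radius [m]
RHO0, H0, H_SCALE = 2.7e-12, 400e3, 59e3  # rough density model near 400 km
CD_A_OVER_M = 2.2 * 4.0 / 500.0   # hypothetical Cd * A / m [m^2/kg]
OMEGA_E = 7.2921159e-5            # Earth rotation rate [rad/s]

def drag_accel(r_vec, v_vec):
    """Cannonball drag with an exponential atmosphere (a deliberately
    simple model; any drag unseen in training works for this test)."""
    h = jnp.linalg.norm(r_vec) - RE
    rho = RHO0 * jnp.exp(-(h - H0) / H_SCALE)
    # Velocity relative to the co-rotating atmosphere.
    v_rel = v_vec - jnp.cross(jnp.array([0.0, 0.0, OMEGA_E]), r_vec)
    return -0.5 * rho * CD_A_OVER_M * jnp.linalg.norm(v_rel) * v_rel

a = drag_accel(jnp.array([RE + 400e3, 0.0, 0.0]),
               jnp.array([0.0, 7670.0, 0.0]))
print(a)  # small acceleration opposing the relative velocity
```

Adding this term to the J2 right-hand side when generating the test set, while leaving the training set untouched, gives a clean out-of-distribution perturbation to stress-test against.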

6. Write a conference-style paper section

Synthesise findings into a 4-page report in the style of an AAS/AIAA Astrodynamics Specialist Conference paper. Include a methodology section explaining the adjoint method without full derivation, a results section with error plots and tables, and a discussion section addressing model limitations, data requirements, and potential for operational orbit determination applications.

Go Further

  • Augment the Neural ODE with a known physics prior: parameterise only the unmodelled acceleration residual (a "physics-residual" model) and show that it trains faster and generalises better than a pure neural approach.
  • Implement a Gaussian process ODE variant and compare uncertainty quantification against the deterministic Neural ODE.
  • Apply the trained propagator inside an unscented Kalman filter for orbit determination from simulated radar measurements.