Neural ODE Orbit Propagation with JAX
Let automatic differentiation learn orbital physics from trajectory data.
Last reviewed: March 2026
Overview
Orbit propagation is the problem of predicting a satellite's future position from its current state. Classical approaches use analytical or numerical integration of Newton's equations augmented with perturbation models (J2 oblateness, atmospheric drag, solar radiation pressure). While accurate when the perturbation model is correct, classical propagators struggle when the force environment is poorly characterised — as is the case for debris objects, novel propulsion systems, or satellites in non-standard orbits.
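For reference, the classical formulation is Newton's second law with a sum of perturbing accelerations. The J2 term below is the standard zonal-harmonic expression in Earth-centred inertial coordinates $\mathbf{r} = [x, y, z]^\top$, with $\mu$ the gravitational parameter, $R_E$ the equatorial radius and $r = \lVert \mathbf{r} \rVert$:

$$
\ddot{\mathbf{r}} = -\frac{\mu}{r^{3}}\,\mathbf{r} + \mathbf{a}_{J_2} + \mathbf{a}_{\text{drag}} + \mathbf{a}_{\text{SRP}} + \cdots
$$

$$
\mathbf{a}_{J_2} = -\frac{3}{2}\,J_2\,\frac{\mu R_E^{2}}{r^{5}}
\begin{bmatrix}
x\left(1 - 5z^{2}/r^{2}\right) \\
y\left(1 - 5z^{2}/r^{2}\right) \\
z\left(3 - 5z^{2}/r^{2}\right)
\end{bmatrix}
$$

A classical propagator is only as good as the terms it includes; an unmodelled $\mathbf{a}_{\text{drag}}$ or mis-specified $\mathbf{a}_{\text{SRP}}$ shows up directly as trajectory error.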
Neural ODEs (introduced by Chen et al., 2018) offer an alternative: instead of specifying the dynamics equations analytically, you parameterise the derivative function with a neural network and learn it from trajectory observations. JAX, Google's high-performance numerical computing library, makes this tractable: JIT compilation keeps the integration loop fast, and automatic differentiation lets gradients flow through the ODE solve, either by backpropagating through the solver steps (with checkpointing to bound memory) or via the continuous adjoint method of Chen et al. The result is a differentiable physics simulator that can be trained end-to-end on measured data.
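The idea can be stated compactly. With the state written as $\mathbf{y} = [\mathbf{r}^\top, \mathbf{v}^\top]^\top$, a loss $L$ on the predicted states and adjoint state $\boldsymbol{\lambda}$, the Neural ODE and the adjoint equations of Chen et al. (2018) are:

$$
\frac{d\mathbf{y}}{dt} = f_\theta(t, \mathbf{y}), \qquad
\boldsymbol{\lambda}(t) = \frac{\partial L}{\partial \mathbf{y}(t)}
$$

$$
\frac{d\boldsymbol{\lambda}}{dt} = -\boldsymbol{\lambda}(t)^{\top}\,\frac{\partial f_\theta}{\partial \mathbf{y}}, \qquad
\frac{dL}{d\theta} = -\int_{t_1}^{t_0} \boldsymbol{\lambda}(t)^{\top}\,\frac{\partial f_\theta}{\partial \theta}\,dt
$$

The adjoint is integrated backwards from $t_1$ to $t_0$, so memory does not grow with the number of solver steps; when the loss uses observations at several times $t_i$, $\boldsymbol{\lambda}$ is incremented by $\partial L / \partial \mathbf{y}(t_i)$ at each one. Backpropagating through the discretised solver instead gives exact gradients of the numerical solution at a higher memory cost.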
This project walks you through implementing a Neural ODE propagator for Earth-orbiting satellites, using simulated two-body + J2 trajectories as training data, and then tests how well the model extrapolates beyond its training distribution. Along the way you will build intuition for automatic differentiation, adaptive numerical integration, and the interplay between data-driven and physics-based modelling, a combination increasingly explored in spacecraft navigation and space situational awareness.
What You'll Learn
- ✓ Implement basic orbital mechanics (two-body, J2) in JAX using functional programming patterns
- ✓ Build a Neural ODE using Diffrax (JAX ODE solver) and Equinox (neural network library)
- ✓ Train differentiable ODE models by backpropagating through the solver, using checkpointing or the adjoint sensitivity method
- ✓ Compare Neural ODE propagation accuracy against classical propagators over multi-orbit time horizons
- ✓ Analyse Neural ODE generalisation: where learned dynamics fail outside the training distribution
Step-by-Step Guide
Implement classical orbit propagators in JAX
Write two-body and J2-perturbed propagators as pure JAX functions using diffrax.diffeqsolve with a Dormand-Prince 5(4) integrator (diffrax.Dopri5). Generate a dataset of 500 simulated LEO trajectories (varying inclination, eccentricity and RAAN), with position and velocity sampled at 60-second intervals over one orbital period.
Define the Neural ODE architecture
Using the Equinox library, define an MLP that maps state (position, velocity, time) to its derivative (velocity, acceleration). Wrap this network in a diffrax ODE term so the neural network acts as the right-hand-side function of the orbital equations. Verify that forward integration produces a valid trajectory shape before training.
Implement the adjoint training loop
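Use jax.grad (or its Equinox counterpart, eqx.filter_grad) through the diffrax.diffeqsolve call to compute gradients of a position MSE loss with respect to the network parameters. Diffrax's default, diffrax.RecursiveCheckpointAdjoint, backpropagates through the solver's internal steps and uses recursive checkpointing to limit memory; diffrax.BacksolveAdjoint implements the continuous adjoint method of Chen et al. Implement a training loop with the Adam optimiser from Optax, and monitor the training loss and integrated trajectory error per epoch.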
Train and evaluate propagation accuracy
Train the Neural ODE on 400 trajectories and evaluate it on a held-out set of 100. Plot position error versus time horizon for both the Neural ODE and the classical J2 propagator. Investigate whether the Neural ODE recovers the J2 effects by decomposing its position errors into radial, along-track and cross-track components.
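The two diagnostics in this step can be computed as in the sketch below. It assumes that reference and predicted trajectories are stored as arrays of shape (N, T, 6) in km and km/s, and it builds the local frame from the reference orbit: radial along r, cross-track along the orbit normal r × v, and along-track (in-track) completing the right-handed triad. The function names are hypothetical.

```python
import jax.numpy as jnp


def ric_errors(ref, pred):
    """Radial / along-track / cross-track position error of `pred` relative to `ref`.

    ref, pred: arrays of shape (..., 6) holding [r (km), v (km/s)].
    """
    r, v = ref[..., :3], ref[..., 3:]
    r_hat = r / jnp.linalg.norm(r, axis=-1, keepdims=True)
    h = jnp.cross(r, v)                                    # orbit normal direction
    c_hat = h / jnp.linalg.norm(h, axis=-1, keepdims=True)
    i_hat = jnp.cross(c_hat, r_hat)                        # completes the right-handed frame
    dr = pred[..., :3] - r
    return jnp.stack(
        [jnp.sum(dr * r_hat, -1), jnp.sum(dr * i_hat, -1), jnp.sum(dr * c_hat, -1)], -1
    )


def error_vs_horizon(ref, pred):
    """Mean 3D position error [km] across trajectories at each time step; inputs (N, T, 6)."""
    return jnp.mean(jnp.linalg.norm(pred[..., :3] - ref[..., :3], axis=-1), axis=0)
```

Plotting `error_vs_horizon` for each propagator against the sample times gives the error-growth curves, and averaging `jnp.abs(ric_errors(...))` over trajectories shows which component dominates the Neural ODE's error.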
Stress-test generalisation with drag perturbations
Generate a second test dataset that includes atmospheric drag (a perturbation not seen during training). Evaluate propagation error for the Neural ODE vs. a J2-only classical propagator on these trajectories. Analyse whether the neural model partially captures drag effects or fails gracefully, and quantify the error growth rate.
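A sketch of the drag-perturbed truth model and one way to quantify error growth follows. It assumes a single-layer exponential atmosphere with illustrative constants (reference density at 400 km altitude and a fixed scale height), an atmosphere co-rotating with the Earth, and a ballistic coefficient defined as Cd·A/m in m²/kg; none of these values come from the project brief. `j2_drag_field` has the `(t, y, args)` signature that `diffrax.ODETerm` expects, with the ballistic coefficient passed as `args`. Fitting a power law err ≈ c·t^p is one possible choice of growth metric, not a prescribed one.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

MU, R_E, J2 = 398600.4418, 6378.137, 1.08262668e-3
OMEGA_E = 7.2921159e-5                       # Earth rotation rate [rad/s]
RHO0, H0, H_SCALE = 2.803e-12, 400.0, 58.515  # illustrative density [kg/m^3], ref. alt. [km], scale height [km]


def drag_accel(r_vec, v_vec, ballistic):
    """Drag acceleration [km/s^2]; ballistic = Cd * A / m [m^2/kg]."""
    rho = RHO0 * jnp.exp(-(jnp.linalg.norm(r_vec) - R_E - H0) / H_SCALE)
    v_rel = v_vec - jnp.cross(jnp.array([0.0, 0.0, OMEGA_E]), r_vec)  # co-rotating atmosphere
    # 0.5 * B * rho * |v| v is m/s^2 in SI; with v in km/s, the factor 1e3 yields km/s^2
    return -0.5 * ballistic * rho * 1e3 * jnp.linalg.norm(v_rel) * v_rel


def j2_drag_field(t, y, ballistic):
    """Two-body + J2 + drag vector field for y = [r (km), v (km/s)]."""
    r_vec, v_vec = y[:3], y[3:]
    r = jnp.linalg.norm(r_vec)
    k, s = 1.5 * J2 * (R_E / r) ** 2, 5.0 * (r_vec[2] / r) ** 2
    grav = -MU * r_vec / r**3 * jnp.array([1 + k * (1 - s), 1 + k * (1 - s), 1 + k * (3 - s)])
    return jnp.concatenate([v_vec, grav + drag_accel(r_vec, v_vec, ballistic)])


def growth_exponent(ts, err):
    """Fit err ~ c * t**p over t > 0 and return the exponent p."""
    mask = ts > 0
    p, _ = jnp.polyfit(jnp.log(ts[mask]), jnp.log(err[mask]), 1)
    return p
```

Comparing the fitted exponent for the Neural ODE and for the J2-only propagator on the drag dataset gives a single number per model for how quickly the unmodelled perturbation degrades each one.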
Write a conference-style paper section
Synthesise findings into a 4-page report in the style of an AAS/AIAA Astrodynamics Specialist Conference paper. Include a methodology section explaining the adjoint method without full derivation, a results section with error plots and tables, and a discussion section addressing model limitations, data requirements, and potential for operational orbit determination applications.
Career Connection
See how this project connects to real aerospace careers.
Space Operations
Satellite operators need accurate propagators for conjunction analysis and manoeuvre planning; neural propagators are an emerging tool in this domain.
Aerospace Engineer
Spacecraft GNC engineers increasingly combine physics models with machine learning; Neural ODE skills are a direct entry point to this hybrid modelling paradigm.
Astronaut
Astronauts working on science missions benefit from understanding how ML tools augment classical orbital mechanics for trajectory planning.
Drone & UAV Ops
The Neural ODE approach to learning dynamics from data applies equally to UAV trajectory modelling in complex wind environments.
Go Further
- Augment the Neural ODE with a known physics prior: parameterise only the unmodelled acceleration residual (a "physics-residual" model) and test whether it trains faster and generalises better than a pure neural approach.
- Implement a Gaussian process ODE variant and compare uncertainty quantification against the deterministic Neural ODE.
- Apply the trained propagator inside an unscented Kalman filter for orbit determination from simulated radar measurements.