JAX

Last reviewed: March 2026

What It Is

JAX is Google DeepMind's high-performance numerical computing library: think of it as NumPy that runs on GPUs and TPUs with automatic differentiation built in. First released in 2018, JAX has become a widely used tool among researchers who need to differentiate through complex computations, compile numerical code for hardware accelerators, and vectorize simulations across thousands of parallel instances.

JAX is completely free and open source under the Apache 2.0 license. It runs on CPUs, NVIDIA GPUs, Google TPUs, and Apple Silicon (GPU support on Apple Silicon is experimental). Unlike TensorFlow and PyTorch, JAX is not primarily a deep learning framework; it is a numerical computing accelerator that happens to be excellent for ML. Its core transformations are jit (just-in-time compilation for speed), grad (automatic differentiation), vmap (automatic vectorization), and pmap (parallel mapping of a computation across multiple devices). Neural network libraries like Flax and Haiku are built on top of JAX for deep learning tasks.
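A minimal sketch of three of the core transformations applied to one small function. `kinetic_energy` is a made-up example function, not part of JAX; pmap is left out because it only does something useful with more than one device.

```python
import jax
import jax.numpy as jnp

def kinetic_energy(v, m):
    """Kinetic energy 0.5 * m * |v|^2 of one body with velocity vector v and mass m."""
    return 0.5 * m * jnp.sum(v ** 2)

# grad: derivative with respect to the first argument (v); analytically m * v
d_energy_dv = jax.grad(kinetic_energy)

# jit: compile the gradient function with XLA on its first call
d_energy_dv_fast = jax.jit(d_energy_dv)

# vmap: map over a batch of velocity vectors (axis 0) while sharing one mass (None)
batched_energy = jax.vmap(kinetic_energy, in_axes=(0, None))

v = jnp.array([3.0, 4.0])
velocities = jnp.array([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])

print(kinetic_energy(v, 2.0))            # 25.0
print(d_energy_dv_fast(v, 2.0))          # gradient m * v = (6, 8)
print(batched_energy(velocities, 2.0))   # energies 1, 4, 25
```

The transformations compose: `jax.jit(jax.vmap(jax.grad(kinetic_energy), in_axes=(0, None)))` would compile a batched gradient in one step.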

For aerospace, JAX's killer feature is differentiable programming. You can write a physics simulation (orbital mechanics, fluid dynamics, structural analysis) and JAX will automatically compute gradients through the entire computation. This means you can optimize trajectories, shapes, and control policies using gradient descent on the actual physics rather than on simplified surrogate models, which is why differentiable programming is drawing growing attention in computational aerospace engineering.
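To make "gradients through the entire computation" concrete, here is a minimal sketch: a body falling with quadratic drag is integrated for 500 explicit-Euler steps, and `jax.grad` returns the derivative of the final altitude with respect to the drag coefficient. The drag model, step size, and parameter values are illustrative assumptions; a central finite difference serves as an independent check.

```python
import jax
import jax.numpy as jnp

# Double precision makes the finite-difference comparison below meaningful.
jax.config.update("jax_enable_x64", True)

G = 9.81      # gravitational acceleration, m/s^2
DT = 0.01     # integration step, s
STEPS = 500   # 5 s of simulated fall

def altitude_after_fall(drag_coeff, h0=1000.0):
    """Explicit-Euler fall from rest with quadratic drag (acceleration per unit mass)."""
    def step(_, state):
        h, v = state[0], state[1]
        a = -G - drag_coeff * v * jnp.abs(v)   # drag always opposes the velocity
        return jnp.stack([h + DT * v, v + DT * a])
    state0 = jnp.array([h0, 0.0])
    return jax.lax.fori_loop(0, STEPS, step, state0)[0]

k = 0.002  # drag coefficient per unit mass, 1/m (illustrative value)
dh_dk = jax.grad(altitude_after_fall)(k)

# Central finite difference as an independent check of the autodiff gradient
eps = 1e-6
fd = (altitude_after_fall(k + eps) - altitude_after_fall(k - eps)) / (2 * eps)
print(dh_dk, fd)   # the two agree closely; positive, since more drag slows the fall
```

Unlike the finite difference, the autodiff gradient needs no step-size tuning, and reverse mode returns the derivatives with respect to all inputs at a cost comparable to a few simulation runs.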

Aerospace Applications

JAX occupies a specialized but increasingly important niche in aerospace: anywhere you need to differentiate through physics or run massively parallel simulations.

Differentiable Trajectory Optimization

Trajectory optimization methods such as direct collocation and shooting need derivatives of the dynamics, which have traditionally come from hand-derived Jacobians or finite differences; both become costly and error-prone as problem complexity grows. JAX supplies exact gradients by differentiating directly through the orbital dynamics code, either to feed those methods or to optimize the trajectory directly with gradient-based solvers. Researchers at Stanford, MIT, and the University of Texas have used JAX to optimize low-thrust spacecraft trajectories, atmospheric entry profiles, and multi-body gravitational maneuvers, reportedly solving in minutes problems that previously required days of computation.
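A minimal sketch of gradient-based targeting by single shooting: the unknown is the departure velocity, and Newton's method uses `jax.jacfwd` to differentiate the final position through every RK4 step of a planar two-body propagator. Canonical units (μ = 1), the fixed-step RK4 integrator, and the synthetic target are assumptions chosen for illustration; this is not the method of the research groups mentioned above, and a real low-thrust problem would optimize a whole control history against a cost.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

MU = 1.0                     # gravitational parameter, canonical units (assumed)
R0 = jnp.array([1.0, 0.0])   # departure position
TOF = 2.0                    # time of flight, canonical time units

def two_body(state):
    """Planar two-body dynamics; state = [x, y, vx, vy]."""
    r, v = state[:2], state[2:]
    a = -MU * r / jnp.linalg.norm(r) ** 3
    return jnp.concatenate([v, a])

def propagate(state0, tof=TOF, n_steps=400):
    """Fixed-step RK4 propagation; every step is traced, so JAX can differentiate it."""
    dt = tof / n_steps
    def rk4_step(s, _):
        k1 = two_body(s)
        k2 = two_body(s + 0.5 * dt * k1)
        k3 = two_body(s + 0.5 * dt * k2)
        k4 = two_body(s + dt * k3)
        return s + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None
    final, _ = jax.lax.scan(rk4_step, state0, None, length=n_steps)
    return final

def miss_vector(v0, r_target):
    """Final-position error for a given departure velocity."""
    return propagate(jnp.concatenate([R0, v0]))[:2] - r_target

@jax.jit
def newton_step(v0, r_target):
    J = jax.jacfwd(miss_vector)(v0, r_target)   # 2x2: d(final position)/d(initial velocity)
    return v0 - jnp.linalg.solve(J, miss_vector(v0, r_target))

# Build a reachable target by propagating a known velocity, then recover that velocity.
v_true = jnp.array([0.05, 1.05])
r_target = propagate(jnp.concatenate([R0, v_true]))[:2]

v0 = jnp.array([0.0, 1.0])   # initial guess: circular-orbit velocity
for _ in range(10):
    v0 = newton_step(v0, r_target)
print(v0)   # converges to v_true
```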

Differentiable CFD and Aerodynamic Shape Optimization

JAX-CFD (Google Research) implements differentiable Navier-Stokes solvers — meaning you can compute how changing an airfoil shape affects drag through the entire fluid simulation, not through a surrogate model. This enables:

  • Inverse design: Specify desired aerodynamic performance, and gradient descent finds the geometry that produces it
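  • Faster shape optimization: Iterate on wing profiles, nacelle shapes, and control surface geometries using gradients from reverse-mode automatic differentiation, which yields the same sensitivities as a discrete adjoint solver without one having to be derived and maintained by hand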
  • Multi-objective optimization: Simultaneously optimize for lift, drag, structural weight, and manufacturability using differentiable physics
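JAX-CFD's own API is not shown here. As a much smaller stand-in, the sketch below differentiates through every time step of a 1-D explicit heat-equation solver and uses Gauss-Newton iterations to recover the diffusivity that reproduces a target profile: the same "specify the output, solve for the input" pattern that inverse design relies on. The grid, time step, initial profile, and target are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

N = 50                      # grid points on [0, 1]
DX = 1.0 / (N - 1)
DT = 1e-3                   # explicit scheme is stable while DT * kappa / DX**2 <= 0.5 (kappa <= ~0.2)
STEPS = 100                 # simulate to t = 0.1
X = jnp.linspace(0.0, 1.0, N)
U0 = jnp.sin(jnp.pi * X)    # initial temperature profile

def simulate(kappa):
    """Explicit finite-difference heat equation u_t = kappa * u_xx with fixed end values."""
    def step(u, _):
        lap = (u[2:] - 2.0 * u[1:-1] + u[:-2]) / DX ** 2
        u_next = u.at[1:-1].set(u[1:-1] + DT * kappa * lap)  # end points keep their initial values
        return u_next, None
    u_final, _ = jax.lax.scan(step, U0, None, length=STEPS)
    return u_final

TRUE_KAPPA = 0.1
target = simulate(TRUE_KAPPA)   # the "desired performance" the design must reproduce

def residual(kappa):
    return simulate(kappa) - target

@jax.jit
def gauss_newton_step(kappa):
    r = residual(kappa)
    J = jax.jacfwd(residual)(kappa)   # sensitivity of the whole final profile to kappa, shape (N,)
    return kappa - jnp.dot(J, r) / jnp.dot(J, J)

kappa = 0.05                  # initial design guess
for _ in range(8):
    kappa = gauss_newton_step(kappa)
print(kappa)                  # converges to TRUE_KAPPA
```

With a single scalar design variable forward mode (`jacfwd`) is the natural choice; for a geometry described by many parameters, reverse mode (`jax.grad` of a scalar loss) would be the cheaper option.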

Model Predictive Control for Autonomous Flight

JAX's ability to JIT-compile dynamics models and differentiate through them makes it well suited to model predictive control (MPC), which computes optimal control inputs by solving an optimization problem at every time step. Research groups have demonstrated JAX-based MPC for quadrotor control, fixed-wing flight, and spacecraft attitude control running at real-time rates on modest hardware.
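A minimal receding-horizon sketch, assuming a 1-D double integrator (position and velocity, commanded acceleration) as a stand-in for a vehicle model. At every step the controller builds the horizon cost, gets its gradient and Hessian from `jax.grad` and `jax.hessian`, and applies only the first input of the optimal sequence. The horizon, weights, and time step are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

DT = 0.1       # control interval, s
HORIZON = 20   # prediction horizon, steps (2 s)
RHO = 0.1      # weight on control effort (assumed)

def dynamics(x, u):
    """1-D double integrator: x = [position, velocity], u = commanded acceleration."""
    return jnp.array([x[0] + DT * x[1], x[1] + DT * u])

def horizon_cost(u_seq, x0):
    """Quadratic cost of applying the control sequence u_seq starting from state x0."""
    def step(x, u):
        x_next = dynamics(x, u)
        return x_next, jnp.sum(x_next ** 2)
    _, state_costs = jax.lax.scan(step, x0, u_seq)
    return jnp.sum(state_costs) + RHO * jnp.sum(u_seq ** 2)

@jax.jit
def mpc_action(x0):
    u = jnp.zeros(HORIZON)
    g = jax.grad(horizon_cost)(u, x0)
    H = jax.hessian(horizon_cost)(u, x0)
    # Linear dynamics and a quadratic cost make one Newton step the exact optimum.
    u_opt = u - jnp.linalg.solve(H, g)
    return u_opt[0]   # receding horizon: apply only the first input

x = jnp.array([1.0, 0.0])   # start 1 m from the origin, at rest
for _ in range(100):        # 10 s of closed-loop simulation
    x = dynamics(x, mpc_action(x))
print(x)   # position and velocity driven close to zero
```

For nonlinear dynamics the single Newton step would become an iterative solve (several Newton or gradient steps, warm-started from the previous solution), but the structure of the loop stays the same.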

Massively Parallel Scientific Simulations

JAX's vmap transformation automatically vectorizes computations, so 10,000 Monte Carlo runs of a satellite constellation, a rocket trajectory, or a structural failure mode can be evaluated as one batched computation on a single GPU, provided the batch fits in memory. For uncertainty quantification and reliability analysis this is often orders of magnitude faster than running the cases one after another.
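A minimal sketch of a vmap Monte Carlo study: a single-trajectory function for a projectile with quadratic drag is written once, then `jax.vmap` evaluates it over 10,000 dispersed launch conditions in one batched call. The dispersions, drag value, and fixed 3 s flight (ground impact is not modeled) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

G = 9.81       # m/s^2
DT = 0.01      # s
STEPS = 300    # fixed 3 s flight; ground impact is not modeled in this toy
DRAG = 1e-3    # quadratic drag per unit mass, 1/m (illustrative)

def downrange(speed, angle):
    """Downrange distance after 3 s for one projectile launched at (speed, angle)."""
    def step(state, _):
        x, y, vx, vy = state
        spd = jnp.sqrt(vx ** 2 + vy ** 2)
        new_state = jnp.array([x + DT * vx,
                               y + DT * vy,
                               vx - DT * DRAG * spd * vx,
                               vy - DT * (G + DRAG * spd * vy)])
        return new_state, None
    state0 = jnp.array([0.0, 0.0, speed * jnp.cos(angle), speed * jnp.sin(angle)])
    final, _ = jax.lax.scan(step, state0, None, length=STEPS)
    return final[0]

# Dispersed launch conditions: speed ~ N(100, 2^2) m/s, angle ~ N(45 deg, (0.5 deg)^2)
N_RUNS = 10_000
k_speed, k_angle = jax.random.split(jax.random.PRNGKey(0))
speeds = 100.0 + 2.0 * jax.random.normal(k_speed, (N_RUNS,))
angles = jnp.deg2rad(45.0) + jnp.deg2rad(0.5) * jax.random.normal(k_angle, (N_RUNS,))

# vmap turns the single-trajectory function into a batched one; jit compiles the batch once
batched_downrange = jax.jit(jax.vmap(downrange))
ranges = batched_downrange(speeds, angles)
print(ranges.mean(), ranges.std())
```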

Accelerated Molecular Dynamics for Materials

JAX MD (from Google) provides differentiable, GPU-accelerated molecular dynamics simulations. In aerospace materials research it can be used to study high-temperature alloy behavior, ceramic matrix composites, and thermal protection system materials at the atomic scale.
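The idea that makes differentiable MD convenient can be sketched in plain JAX (this does not use JAX MD's API): write the potential energy once, and the forces are its negative gradient, F = -∇E, obtained with `jax.grad`. The Lennard-Jones potential in reduced units (ε = σ = 1) is an assumed example; real materials work would use fitted interatomic potentials, neighbor lists, and periodic boundaries.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

EPS, SIGMA = 1.0, 1.0   # reduced Lennard-Jones units (assumed)

def lj_energy(positions):
    """Total Lennard-Jones energy of an (n, 3) array of particle positions."""
    diff = positions[:, None, :] - positions[None, :, :]
    r2 = jnp.sum(diff ** 2, axis=-1)
    off_diag = ~jnp.eye(positions.shape[0], dtype=bool)
    # Replace zero self-distances with 1.0 so no division by zero reaches the gradient.
    r2_safe = jnp.where(off_diag, r2, 1.0)
    s6 = (SIGMA ** 2 / r2_safe) ** 3
    pair = 4.0 * EPS * (s6 ** 2 - s6)
    return 0.5 * jnp.sum(jnp.where(off_diag, pair, 0.0))   # each pair is counted twice

# Forces are the negative gradient of the energy with respect to positions.
forces = jax.jit(lambda pos: -jax.grad(lj_energy)(pos))

pair_at_minimum = jnp.array([[0.0, 0.0, 0.0], [2.0 ** (1.0 / 6.0), 0.0, 0.0]])
print(lj_energy(pair_at_minimum))   # approximately -1.0 at r = 2^(1/6) sigma
print(forces(pair_at_minimum))      # approximately zero: the pair sits at the energy minimum
```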

Getting Started

High School

JAX is not a beginner tool. It requires comfort with Python, NumPy, linear algebra, and calculus. If you're in high school, start with Python and NumPy first, then learn PyTorch or TensorFlow. JAX will be waiting when you have the mathematical maturity to use it — typically junior year of an engineering or physics undergraduate program.

That said, if you're a mathematically advanced high school student (AP Calculus BC, linear algebra), you can explore JAX's automatic differentiation: write a simple function, use jax.grad to compute its derivative, and compare to your hand-calculated result. It's a powerful way to build intuition about gradients.
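The exercise described above fits in a few lines. The function `f` is an arbitrary example; note that `jax.grad` expects floating-point inputs such as 2.0 and rejects integers like 2.

```python
import jax

def f(x):
    return x ** 3 + 2.0 * x      # derivative by hand: f'(x) = 3x^2 + 2

df = jax.grad(f)                 # first derivative
d2f = jax.grad(df)               # second derivative; by hand f''(x) = 6x

print(df(2.0))    # 14.0, since 3 * 2^2 + 2 = 14
print(d2f(2.0))   # 12.0, since 6 * 2 = 12
```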

Undergraduate

JAX becomes relevant in junior/senior year when you're comfortable with multivariable calculus, linear algebra, differential equations, and numerical methods. Entry points:

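  • Numerical methods course: Reimplement ODE solvers (Runge-Kutta) in JAX, then compute sensitivities (how do initial conditions affect the final state?) with jax.grad for a scalar output or jax.jacfwd for the full final-state Jacobian
  • Orbital mechanics: Write a two-body propagator in JAX, differentiate through it to optimize a two-impulse transfer between circular orbits, and check that the optimizer recovers the analytical Hohmann transfer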
  • Controls: Implement a model predictive controller using JAX's JIT compilation for real-time performance
  • Senior design: Use JAX for gradient-based optimization of a design parameter — airfoil shape, structural thickness distribution, or propulsion cycle parameters
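The numerical-methods exercise can be sketched as follows. Because `jax.grad` only handles scalar outputs, `jax.jacfwd` is used to get the full sensitivity of the final state to the initial state (the state transition matrix). The harmonic oscillator x'' = -x is an assumed test problem, chosen because its exact matrix [[cos t, sin t], [-sin t, cos t]] is known; the same pattern applies unchanged to a two-body propagator.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def oscillator(state):
    """Unit-frequency harmonic oscillator x'' = -x; state = [x, v]."""
    x, v = state
    return jnp.array([v, -x])

def rk4(f, state, dt, n_steps):
    """Classic fixed-step fourth-order Runge-Kutta integration."""
    def step(s, _):
        k1 = f(s)
        k2 = f(s + 0.5 * dt * k1)
        k3 = f(s + 0.5 * dt * k2)
        k4 = f(s + dt * k3)
        return s + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None
    final, _ = jax.lax.scan(step, state, None, length=n_steps)
    return final

def final_state(initial_state):
    return rk4(oscillator, initial_state, dt=0.01, n_steps=100)   # integrate to t = 1

# stm[i, j] = d(final_state[i]) / d(initial_state[j])
stm = jax.jacfwd(final_state)(jnp.array([1.0, 0.0]))
print(stm)   # close to [[cos 1, sin 1], [-sin 1, cos 1]]
```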

The official JAX documentation at jax.readthedocs.io includes excellent tutorials. Patrick Kidger's "Equinox" library provides a PyTorch-like neural network interface on top of JAX for students transitioning from PyTorch.

Advanced / Graduate

JAX is increasingly used for computational physics and engineering research:

  • Differentiable physics: Implement full differentiable simulations (CFD, FEA, molecular dynamics) and optimize through them end-to-end
  • Neural ODEs and dynamical systems: Use Diffrax (JAX-based differential equation solver) for learning dynamics of aerospace systems
  • Bayesian inference: NumPyro (JAX-based probabilistic programming) for uncertainty quantification in aerospace models
  • Large-scale optimization: Use JAX's multi-GPU/TPU support for problems too large for single-machine computation

Who should learn JAX: Students interested in computational physics, optimization, or scientific computing for aerospace. If you want to build better ML models, learn PyTorch. If you want to differentiate through physics simulations and solve optimization problems with gradient methods, learn JAX. A growing number of aerospace research groups are adopting JAX for its speed and its composable function transformations.

Career Connection

| Role | How JAX Is Used | Typical Employers | Salary Range |
|---|---|---|---|
| Computational Scientist (Aerospace) | Develop differentiable physics simulations for aerodynamic and structural optimization using JAX's autodiff and JIT capabilities | NASA research centers, Sandia, LLNL, university labs | $120K–$180K |
| GN&C Engineer (Guidance, Navigation, and Control) | Build optimal trajectory planners and model predictive controllers using JAX-accelerated dynamics models | SpaceX, Blue Origin, Relativity Space, Rocket Lab | $130K–$190K |
| Research Scientist (Scientific ML) | Develop differentiable solvers, neural operators, and physics-informed architectures for aerospace applications | Google DeepMind, NVIDIA Research, MIT, Stanford | $150K–$250K |
| Optimization Engineer | Gradient-based multidisciplinary design optimization for aircraft and spacecraft using differentiable analysis tools built in JAX | Boeing Research, Airbus Innovation, Aurora Flight Sciences | $120K–$170K |
| Astrodynamics Analyst | High-fidelity trajectory optimization, constellation design, and mission planning using JAX-accelerated propagators | JPL, Aerospace Corporation, AGI (Ansys), Kayhan Space | $110K–$160K |
Verified March 2026