The JAX documentation also includes Autodidax ("JAX core from scratch"), a tutorial that walks through re-implementing JAX's core yourself:
https://docs.jax.dev/en/latest/autodidax.html
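A central idea in the tutorial is that transformations such as `jvp` can be built as interpreters over traced values. The sketch below is a much simplified, plain-Python illustration of forward-mode differentiation with tracer objects that carry a value and its tangent. It is not the tutorial's code. Autodidax routes operations through primitives and a stack of interpreters so that transformations can nest, and this sketch omits all of that. The names `JVPTracer`, `lift`, `sin`, and `jvp` here are this sketch's own and are hypothetical.

```python
import math

class JVPTracer:
    """Hypothetical tracer: carries a primal value and its tangent."""
    def __init__(self, primal, tangent):
        self.primal = primal
        self.tangent = tangent

    def __add__(self, other):
        other = lift(other)
        # Sum rule: d(x + y) = dx + dy
        return JVPTracer(self.primal + other.primal, self.tangent + other.tangent)
    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        # Product rule: d(x * y) = x dy + y dx
        return JVPTracer(self.primal * other.primal,
                         self.primal * other.tangent + self.tangent * other.primal)
    __rmul__ = __mul__

def lift(x):
    # Plain constants are treated as tracers with a zero tangent.
    return x if isinstance(x, JVPTracer) else JVPTracer(x, 0.0)

def sin(x):
    # Chain rule: d(sin x) = cos(x) dx
    x = lift(x)
    return JVPTracer(math.sin(x.primal), math.cos(x.primal) * x.tangent)

def jvp(f, primals, tangents):
    # Wrap each input in a tracer, run f, and read off value and tangent.
    tracers = [JVPTracer(p, t) for p, t in zip(primals, tangents)]
    out = lift(f(*tracers))
    return out.primal, out.tangent

def f(x):
    return 3.0 * x * x + sin(x)

# Value and derivative of f at x = 2.0, i.e. 12 + sin(2) and 12 + cos(2).
primal, tangent = jvp(f, (2.0,), (1.0,))
print(primal, tangent)
```

Operator overloading keeps the sketch short; the tutorial's interpreter-based design is what lets the same machinery also support `vmap`, `jit`, and nested transformations.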