Let f: R -> R be an infinitely differentiable function. What is the computational complexity of calculating the first n derivatives of f in Jax? The naive chain rule would suggest that each multiplication gives a factor-of-2 increase, hence the nth derivative would require at least 2^n more operations. I imagine, though, that clever manipulation of formal series would reduce the number of required calculations and eliminate duplications, especially if the derivatives are Jax-jitted. Is there a difference between the Jax, Tensorflow and Torch implementations?

https://openreview.net/forum?id=SkxEF3FNPH discusses this topic, but doesn't provide a computational complexity.
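For concreteness, here is a minimal sketch of what I mean by computing the first n derivatives naively (the particular `f` is just an example I picked for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * jnp.exp(-x)  # an example of a smooth f: R -> R

n = 5
derivs = []
g = f
for _ in range(n):
    g = jax.grad(g)            # each level differentiates the previous one
    derivs.append(jax.jit(g))  # jit may eliminate some duplicated subexpressions

x = 1.0
print([float(d(x)) for d in derivs])  # f'(x), f''(x), ..., f^(n)(x)
```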
There's not much you can say in general about the computational complexity of Nth derivatives. For example, with a function like `jnp.sin`, the Nth derivative is `O[1]`, oscillating between negative and positive `sin` and `cos` calls as N grows. For an order-k polynomial, the Nth derivative is `O[0]` for N > k. Other functions may have complexity that is linear or polynomial or even exponential with `N`, depending on the operations they contain.
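If you want a concrete picture for a particular function, one quick (and rough) way is to count the primitive operations in the jaxpr of each nested derivative; this is just an illustration of how to inspect the traced cost, not a general benchmark:

```python
import jax
import jax.numpy as jnp

def num_ops(fun, x):
    # Number of primitive equations in the jaxpr traced from `fun` at `x`.
    return len(jax.make_jaxpr(fun)(x).jaxpr.eqns)

g = jnp.sin
for k in range(1, 6):
    g = jax.grad(g)
    print(f"derivative order {k}: {num_ops(g, 1.0)} primitives")
```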
You imagine correctly! One implementation of this idea is the `jax.experimental.jet` module, which is an experimental transform designed for computing higher-order derivatives efficiently and accurately. It doesn't cover all JAX functions, but it may be complete enough to do what you have in mind.
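Roughly, the usage looks like this (a sketch with an example `f` of my choosing): seed the input's Taylor series with (1, 0, ..., 0), i.e. the path x(t) = x + t, and the output series coefficients are then the derivatives of f at x.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet

def f(x):
    return jnp.sin(x) * jnp.exp(-x)  # just an example function

x = 1.0
n = 5

# With the seed (1, 0, ..., 0) the output series is [f'(x), f''(x), ..., f^(n)(x)],
# all propagated in a single forward pass.
primal_out, taylor_derivs = jet(f, (x,), [(1.0,) + (0.0,) * (n - 1)])

# Compare against nesting jax.grad n times.
g = f
for k in range(n):
    g = jax.grad(g)
    print(k + 1, float(taylor_derivs[k]), float(g(x)))
```

The point of `jet` is that it propagates all n Taylor coefficients in one pass through the function, rather than nesting n differentiation transforms, which is where the savings over repeated `grad` come from.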