Considering all this, it may be worth highlighting that tools like JAX exist as well (https://github.com/google/jax). These have even become an expected integration in some toolkits (e.g., NumPyro).
It may not be the most elegant approach, but there's a lot of power in something that "mostly just works and then we can optimize narrowly once we find a problem"
That doesn't make a solution to this mess a bad idea, but I do wonder whether it would only serve a narrow niche.
@tschenkel
Its main advantage, as far as arrays go, is the ability to push work out to an accelerator (GPU) without code changes. Its JIT is also a good bit faster than PyTorch's (at least anecdotally).
My experience with it is not at all related to ODEs (more things like MCMC). I have no direct experience with its gradient functionality and only limited experience with its auto-vectorization, so take all of this with a grain of salt.
@maegul @astrojuanlu @programming
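For anyone who hasn't seen it, here is a minimal sketch of the JIT and auto-vectorization points above. The function `log_density` and the sample array are made up for illustration (a toy stand-in for the kind of thing you'd evaluate in MCMC), not taken from any real project. The same code should run on CPU or GPU, depending on which backend the installed JAX finds.

```python
import jax
import jax.numpy as jnp


def log_density(x):
    # Hypothetical example: unnormalized log-density of a standard normal.
    return -0.5 * jnp.sum(x ** 2)


# jax.jit traces the function and compiles it with XLA for the default backend.
fast_log_density = jax.jit(log_density)

# jax.vmap maps the single-point function over the leading axis of its input,
# so no hand-written batching loop is needed.
batched_log_density = jax.jit(jax.vmap(log_density))

points = jnp.arange(12.0).reshape(4, 3)  # four 3-dimensional points

single = fast_log_density(points[0])
batch = batched_log_density(points)

# Lists the devices JAX will use (CPU, GPU, ...); the code above is unchanged.
print(jax.devices())
print(single, batch)
```

Note that nothing in the function body mentions a device or a batch dimension; both are handled by the `jit` and `vmap` wrappers.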