Glue AD for Full Language Differentiable Programming

By: Christopher Rackauckas

Re-posted from: http://www.stochasticlifestyle.com/glue-ad-for-full-language-differentiable-programming/

No design choice will be the best choice for all possible users. That statement is provocative, but at the same time I think everyone would readily agree with it. Yet it should make us all question whether it's a good idea to ever try to make all users happy with one piece of code. Under the differentiable programming mindset we are trying to make all code in the entire programming language differentiable, but why would we think that a single system with a single set of rules and assumptions would be the best for everyone?

Optimized Algorithms Across Scientific Computing and Machine Learning

Differentiable programming is a subset of modeling where you model with a program whose steps are each differentiable, for the purpose of finding the correct program by parameter fitting using said derivatives. Just like any modeling domain, different problems have different code styles which must be optimized in different ways. Traditional scientific computing code writes out nonlinear scalar operations on mutable buffers and avoids memory allocations in order to keep top performance. On the other hand, many machine learning libraries allocate a ton of temporary arrays due to using out-of-place matrix multiplications, [which is fine because dense linear algebra costs grow much faster than the costs of the allocations](https://www.stochasticlifestyle.com/when-do-micro-optimizations-matter-in-scientific-computing/). Some need sparsity everywhere, others just need to fuse and build the fastest dense kernels possible. Some algorithms do great on GPUs, while some do not. This intersection between scientific computing and machine learning, i.e. scientific machine learning and other applications of differentiable programming, is too large of a domain for one approach to make everyone happy. And if an AD system is unable to reach top-notch performance for a specific subdomain, it's simply better for the hardcore package author to not use the AD system and instead write their own adjoints.

Even worse is the fact that mathematically there are many cases where you should write your own adjoints, since differentiating through the code is very suboptimal. Any iterative algorithm is of this sort: a nonlinear solve f(x)=0 may use Newton's method to find the root x* with f(x*)=0, but the adjoint is defined using only x* and the Jacobian f'(x*), so there's no need to ever differentiate through the Newton iterations themselves. So we should all be writing adjoints! Does this mean that the story of differentiable programming is destroyed? Is it just always better to not do differentiable programming, so any hardcore library writer will ignore it?
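To make this concrete, here is a minimal sketch of such a hand-written adjoint, assuming a hypothetical `newton_solve(f, x0, p)` that solves f(x, p) = 0 for parameters `p`. The rule is expressed with ChainRulesCore.jl (more on ChainRules below), and the Jacobians are formed with ForwardDiff.jl purely for brevity:

```julia
using ChainRulesCore, ForwardDiff, LinearAlgebra

# Hypothetical Newton solver: find x* with f(x*, p) = 0 for parameters p.
function newton_solve(f, x0, p; tol = 1e-10, maxiters = 100)
    x = copy(x0)
    for _ in 1:maxiters
        fx = f(x, p)
        norm(fx) < tol && break
        Jx = ForwardDiff.jacobian(z -> f(z, p), x)  # ∂f/∂x
        x = x - Jx \ fx
    end
    return x
end

# Implicit-function-theorem adjoint: uses only x* and the Jacobians at x*,
# never the Newton iterations. From f(x*(p), p) = 0 we get
#   dx*/dp = -(∂f/∂x)⁻¹ ∂f/∂p,  so  p̄ = -(∂f/∂p)ᵀ (∂f/∂x)⁻ᵀ x̄.
function ChainRulesCore.rrule(::typeof(newton_solve), f, x0, p; kwargs...)
    xstar = newton_solve(f, x0, p; kwargs...)
    Jx = ForwardDiff.jacobian(z -> f(z, p), xstar)   # ∂f/∂x at x*
    Jp = ForwardDiff.jacobian(q -> f(xstar, q), p)   # ∂f/∂p at x*
    function newton_solve_pullback(x̄)
        λ = Jx' \ unthunk(x̄)      # one adjoint linear solve
        return (NoTangent(), NoTangent(), ZeroTangent(), -Jp' * λ)
    end
    return xstar, newton_solve_pullback
end
```

However the solver iterates internally, the reverse pass is just one linear solve at the solution.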

Is There A Common Ground?

Instead of falling into that despair, let's follow down that road in a positive light. Let's assume that the best way to do differentiable programming is to write adjoints on library code. Then what's the purpose of a differentiable programming system? It's to help your adjoints get written and be useful. In machine learning it's just matrix multiplications: if the majority of the code is in some optimized kernel, then you don't need to worry much about the performance of the rest, you just want it to work. With differentiable programming, if 99% of the computation is in the DiffEq/NLsolve/FFTW/etc. adjoints, what we need from a differentiable programming system is something that will get the rest of the adjoint done and be very easy to make correct. The way to facilitate this kind of workflow would be for the differentiable programming system to:

  1. Have very high coverage of the language. Sacrifice some speed if it needs to; that's okay, because if 99% of the compute time is in my adjoint, then I don't want that 1% to be hard to develop. It should just work, however it works.
  2. Be easy to debug and profile. Stacktraces should point to real code. Profiles should point to real lines of code. Errors should be legible.
  3. Have a language-wide system for defining adjoints (see the sketch after this list). We can't have walled gardens if the way to “get good” is to have adjoints for everything: we need everyone to plug in and distribute the work. Not just to the developers of one framework, and not just to the users of one framework, but to every scientific developer in the entire programming language.
  4. Make it easy to swap out AD systems. More constrained systems may be more optimized, and if I don't want to define an adjoint, at least I can fall back to something that (a) works on my code and (b) matches its assumptions.
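In Julia, ChainRules.jl is exactly the kind of language-wide system point (3) asks for: a library author defines one `rrule` for their optimized kernel and every ChainRules-aware AD picks it up. A minimal sketch, with a hypothetical `mysoftplus` kernel standing in for the library code:

```julia
using ChainRulesCore

# Hypothetical optimized library kernel: softplus, written however the
# library author likes internally (mutation, SIMD, GPU kernels, ...).
mysoftplus(x::AbstractVector) = log1p.(exp.(x))

# One rrule definition, and every ChainRules-consuming AD (Zygote, Yota, ...)
# uses it instead of differentiating through the kernel's internals.
function ChainRulesCore.rrule(::typeof(mysoftplus), x::AbstractVector)
    y = mysoftplus(x)
    # d/dx log(1 + eˣ) = eˣ/(1 + eˣ) = 1 - e⁻ʸ
    mysoftplus_pullback(ȳ) = (NoTangent(), unthunk(ȳ) .* (1 .- exp.(-y)))
    return y, mysoftplus_pullback
end
```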

Thus what I think we’re looking for is not one differentiable programming system that is the best in all aspects, but instead we’re looking for a differentiable programming system that can glue together everything that’s out there. “Differentiate all of the things, but also tell me how to do things better”. We’re looking for a glue AD.

If that’s where we need to go, how do we get there?

Zygote is surprisingly close to being this perfect glue AD. Its stacktraces and profiling are fairly good because they point to the pieces generating the backpasses. It just needs some focus on this goal if it wants to attain it. For (1), it would need to get higher coverage of the language, focusing on its expanse more so than on doing everything as fast as possible. Of course, it should do as well as it can, but for example, if it needs to sacrifice a bit of speed to get full support for mutation today, that might be a good trade-off if the goal is to be a glue AD. Perfect? No, but that would give you the coverage to then tell the user: if you need more performance on a particular piece of code, seek out more. To seek out more performance, users could just have Zygote call ReverseDiff.jl on a function and have that compile the tape (or other specialized AD systems which will be announced more broadly soon), or they may want to write a partial adjoint.
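As a sketch of what that ReverseDiff.jl fallback looks like today, on an illustrative function `f` with fixed control flow (the case where compiled tapes apply):

```julia
using ReverseDiff

# Illustrative scalar-returning function whose operations don't branch on
# the input values, which is where ReverseDiff's compiled tapes shine.
f(x) = sum(abs2, x .- 2.0)

x = rand(100)
tape  = ReverseDiff.GradientTape(f, x)   # record the operations once
ctape = ReverseDiff.compile(tape)        # compile the tape for fast replay
grad  = similar(x)
ReverseDiff.gradient!(grad, ctape, x)    # reuse on new inputs of this shape
```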

So (4) is really the kicker. If I were to hit slow mutating code today inside of a differential equation, it would probably be something perfect for ModelingToolkit.jl to handle, so the best thing to do is to build hyper-optimized adjoints of that differential equation using ModelingToolkit.jl. At that level, I can handle it symbolically and generate code that a compiler never could, because I can make a lot of extra assumptions in my mathematical context, like cos^2(x) + sin^2(x) = 1. I can move code around, auto-parallelize it, etc. easily because of the simple static graph I'm working on. Wouldn't it be a treat to just `@ModelingToolkitAdjoint f` and bingo, now it's using ModelingToolkit on a portion of code? `@ForwardDiffAdjoint f` to tell it “you should use forward mode here”. Yota.jl is a great reverse mode project, so `@YotaAdjoint f` and boom, that could be more optimized than Zygote in some cases. `@ReverseDiff f` and let it compile the tape, and it'll get fairly optimal in the places where ReverseDiff.jl is applicable.
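None of those macros exist today, but the mechanism they would expand to already does: a ChainRules overload whose pullback is computed by a different AD. A rough sketch of what a hypothetical `@ForwardDiffAdjoint g` could lower to, for an illustrative scalar-output `g` where forward mode wins (few inputs, heavy scalar nonlinearity):

```julia
using ChainRulesCore, ForwardDiff

# Illustrative scalar-heavy function: few inputs, lots of scalar work.
g(x) = sin(x[1]) * exp(x[2]) + x[1]^3

# Roughly what a hypothetical `@ForwardDiffAdjoint g` could expand to:
# reverse-mode callers get a pullback whose gradient is computed in
# forward mode, swapping AD systems for just this portion of the code.
function ChainRulesCore.rrule(::typeof(g), x::AbstractVector)
    y = g(x)
    g_pullback(ȳ) = (NoTangent(), unthunk(ȳ) .* ForwardDiff.gradient(g, x))
    return y, g_pullback
end
```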

Julia is the perfect language to develop such a system in because its AST is so nice and constrained in mathematical contexts that all of these AD libraries work not on a special DSL like TensorFlow graphs or torch.numpy, but directly on the language itself and its original existing libraries. With ChainRules.jl allowing for adjoint overloads that apply to all AD packages, focusing on these “Glue AD” properties could really open up the playing field, allowing Zygote to be at the center of an expansive differentiable programming world that works everywhere, maybe makes some compromises to do so, but then gives other developers a system for making assumptions, easily defining adjoints, and plugging alternative AD systems into the whole game. This is a true mixed mode which incorporates not just forward and reverse, but also different implementations with different performance profiles (and this can be implemented just through ChainRules overloads!). Zygote would then facilitate this playing field with a solid debugging and profiling experience, along with a very high chance of working on your code on the first try. That plus buy-in from package authors would be a true solution to differentiable programming.
