Engineering Trade-Offs in Automatic Differentiation: from TensorFlow and PyTorch to Jax and Julia

By: Christopher Rackauckas

Re-posted from: http://www.stochasticlifestyle.com/engineering-trade-offs-in-automatic-differentiation-from-tensorflow-and-pytorch-to-jax-and-julia/

To understand the differences between automatic differentiation libraries, let’s talk about the engineering trade-offs that were made. I would personally say that none of these libraries is “better” than another; they all simply make engineering trade-offs based on the domains and use cases they were aiming to satisfy. The easiest way to describe these trade-offs is to follow the evolution and see how each new library tweaked the trade-offs made by the previous one.

Early TensorFlow used a graph-building system, i.e. it required users to essentially define variables in a specific graph language separate from the host language. You had to define “TensorFlow variables” and “TensorFlow ops”, and the AD would then be performed on this static graph. Control flow constructs were limited to those that could be represented statically. For example, an `ifelse` function is very different from a conditional `if`-then-`else` in code, because `ifelse` is semantically the same as always calling both branches and then choosing the result, so there is only a single code path (I say “semantically” because further compiler optimizations may, and usually do, reduce that). This static sublanguage is then represented in an intermediate representation (IR) for XLA, which performs a lot of simplification of linear algebra, and AD was done using simple graph algorithms, since there was no true control flow at this level of representation. While this gives a lot of efficiency (XLA is great for simplification because it can easily see the whole world), it of course had some major downsides in terms of flexibility and convenience.
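To make the `ifelse` semantics concrete, here is a minimal sketch in plain Python, not TensorFlow’s API. It is a hypothetical graph builder (`Node`, `const`, `ifelse`, and `evaluate` are all illustrative names) in which `x > 0`, `x * x`, and `x + x` build graph nodes rather than computing values, and `ifelse` is a single graph op. Evaluating the graph runs both branches and only then selects a result, so the recorded trace contains both the `mul` and the `add`.

```python
# Hypothetical static-graph builder illustrating graph-level ifelse semantics:
# both branches are part of the graph, and both are evaluated.

class Node:
    def __init__(self, op, inputs=(), value=None):
        self.op = op          # operation name: "const", "add", "mul", "gt", "ifelse"
        self.inputs = inputs  # upstream nodes
        self.value = value    # only used by "const" nodes

    def __add__(self, other):
        return Node("add", (self, other))

    def __mul__(self, other):
        return Node("mul", (self, other))

    def __gt__(self, other):
        return Node("gt", (self, other))


def const(v):
    return Node("const", value=v)


def ifelse(cond, then_node, else_node):
    # One graph op, so there is a single code path through the graph.
    return Node("ifelse", (cond, then_node, else_node))


def evaluate(node, trace):
    """Evaluate the graph, appending every non-constant op that runs to `trace`."""
    if node.op == "const":
        return node.value
    args = [evaluate(n, trace) for n in node.inputs]  # includes BOTH branches
    trace.append(node.op)
    if node.op == "add":
        return args[0] + args[1]
    if node.op == "mul":
        return args[0] * args[1]
    if node.op == "gt":
        return args[0] > args[1]
    if node.op == "ifelse":
        cond, a, b = args
        return a if cond else b
    raise ValueError(f"unknown op {node.op}")


x = const(2.0)
g = ifelse(x > const(0.0), x * x, x + x)   # graph for: ifelse(x > 0, x*x, x+x)
trace = []
result = evaluate(g, trace)
print(result, trace)   # 4.0 ['gt', 'mul', 'add', 'ifelse']
```

A compiler working on such a graph is free to prune the unused branch afterwards, which is the optimization caveat mentioned above.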

Thus you can almost think of this as a source code transformation, because all of the autodiff is done on what is essentially an IR for a language that is not the same as the host language, but for the most part it required the user to do the translation into the new language for the AD system, which is… rather inconvenient.

PyTorch came along to solve the flexibility and convenience issues by instead using a tape-based method. It generates the code to autodiff every time you run the forward pass by simply storing the operations that it sees in a given forward pass, and then differentiates that sequence of operations in reverse. This “building of the tape” is done by operator overloading on the Tensor type that PyTorch requires you to use. How it works is easy to see with a function f containing an `if` statement followed by a `while` loop: suppose f(2.0) takes the first branch of the `if` statement and then runs the `while` loop 5 times. The AD pass then takes that recorded set of operations and runs backpropagation through the 5 iterations of the `while` loop and back through the first branch. Notice that in this form the AD does not “see” any dynamic control flow: that all happened in Python, not on the tape. Thus the AD does not have to handle dynamic control flow, which makes it very easy to handle a lot of odd cases of the language. The downside of this approach, though, is that the AD is “per value”: because you will not necessarily ever see the same backwards pass again, there is little opportunity to optimize the backwards passes.
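The tape mechanism can be sketched in a few lines of plain Python. This is a toy scalar version, not PyTorch’s implementation: `Var`, `backward`, `value_and_grad`, and the function `f` are hypothetical, with `f` chosen so that `f(2.0)` takes the first branch and runs its value-dependent `while` loop 5 times, matching the scenario above. The overloaded `+` and `*` append each operation to the tape, and `backward` replays the tape in reverse, applying the chain rule.

```python
class Var:
    """A scalar that records every operation applied to it on a shared tape."""

    def __init__(self, value, tape):
        self.value = value
        self.grad = 0.0
        self.tape = tape

    def _record(self, value, parents):
        # parents: list of (input Var, local partial derivative of the output)
        out = Var(value, self.tape)
        self.tape.append((out, parents))
        return out

    def __add__(self, other):
        return self._record(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return self._record(self.value * other.value,
                            [(self, other.value), (other, self.value)])


def backward(out):
    """Reverse pass: walk the tape backwards, accumulating gradients."""
    out.grad = 1.0
    for node, parents in reversed(out.tape):
        for parent, local in parents:
            parent.grad += local * node.grad


def f(x):
    # Hypothetical example: ordinary Python control flow on concrete values.
    if x.value > 0:
        y = x * x
    else:
        y = x + x
    while y.value < 100:   # value-dependent loop; only executed iterations hit the tape
        y = y * x
    return y


def value_and_grad(fn, v):
    tape = []              # a fresh tape is built on every call ("per value")
    x = Var(v, tape)
    y = fn(x)
    backward(y)
    return y.value, x.grad, len(tape)


print(value_and_grad(f, 2.0))  # (128.0, 448.0, 6): first branch, 5 loop iterations
print(value_and_grad(f, 3.0))  # (243.0, 405.0, 4): first branch, 3 loop iterations
```

Because the `if` and the `while` run as ordinary Python, the tape for `f(2.0)` holds 6 multiplications while the tape for `f(3.0)` holds 4: each input value gets its own freshly built tape, which is the “per value” property that limits optimization of the backwards pass.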

Does this harm PyTorch’s efficiency beyond repair? Well, no and yes. No, in the sense that most machine learning algorithms rely so heavily on expensive kernels, such as matrix multiplication (`A*x`), `conv`, etc., that the amount of work per operation is extremely high and hides the overhead of this approach. This allows the PyTorch team to spend most of its time optimizing the 2,000+ operators that it provides, and so most people in ML see PyTorch as fast because it comes with fast kernels (fast conv calls, fast GPU linear algebra) despite the AD overhead. That said, you can very easily run into cases where the AD and Python interpreter overhead are not washed out: cases where your arrays are small or where many scalar operations are happening. For example, in the Julia vs PyTorch Neural ODE benchmarks on cases matching scientific model discovery workflows, you see a 100x performance improvement in Julia (with major differences even without AD, in the ODE and SDE solvers themselves), and this can mostly be attributed to language and AD overhead due to the small kernels used in these cases. For this reason the PyTorch team has been working on things like `torch.jit` (TorchScript), a separate sublanguage that can compile and optimize differently from the rest of the code, specifically to handle these cases, though there’s a lot of discussion about the long-term viability of that approach. But anyway, PyTorch has done really well because it made good choices for its domain of use.
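As a rough model of the overhead argument above (an illustrative assumption, not a measurement), suppose each operation recorded on the tape pays a roughly constant interpreter and bookkeeping cost $t_{\text{oh}}$ on top of its kernel time $t_{\text{kernel}}$. A dense $n \times n$ matrix multiplication performs about $2n^3$ floating-point operations, so with a sustained throughput of $F$ flop/s the fraction of time spent in overhead is

$$
\frac{t_{\text{oh}}}{t_{\text{oh}} + t_{\text{kernel}}} \approx \frac{t_{\text{oh}}}{t_{\text{oh}} + 2n^{3}/F} \;\longrightarrow\; 0 \quad \text{as } n \to \infty,
$$

while for a scalar operation the kernel time is tiny, $t_{\text{kernel}} \ll t_{\text{oh}}$, and the same fraction approaches 1. The first case is the large-kernel regime where the overhead is washed out; the second is the small-array, scalar-heavy regime where it dominates.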

So then TensorFlow Eager (2.0) comes around and adds dynamic control flow support in a manner similar to PyTorch, as a sad attempt to get everyone back, but of course it then doesn’t play nicely with all of the XLA tooling (because it cannot see the whole graph of all possible operations for all input values to optimize it well), so it didn’t hit the TensorFlow speeds everyone was expecting, and it was kind of the worst of all worlds.

Subsequent tools then all sought ways to either expand the domains of these ideas or mix some of the advantages of the two sides. Jax is one of those. Jax uses non-standard interpretation to build a copy of the full code in its own IR to then perform AD on, finally lowering it to TensorFlow’s XLA for optimizations. Jax’s non-standard interpretation is kind of like operator overloading in that special objects walk through the code in order to build out the expressions (this is called the “tracing” step). But wait, how is it able to trace the full code if there’s dynamic control flow? Won’t it have the same issue as PyTorch, that it only sees some of the code’s potential paths? Indeed that is true, and that’s why it doesn’t want you to use full dynamic control flow and instead wants you to use Jax primitives like `lax.while_loop`, which are function calls that can be caught during tracing so that the code has no true dynamic behavior at trace time. Also, for this to work, what your function does must be completely determined by its inputs, i.e. the functions must be “pure”. For this reason Jax requires programming in a functional style with pure functions rather than Python’s standard object-oriented style, a notable trade-off of the abstract interpretation approach. But what you essentially get is a more natural graph builder for TensorFlow: at the end of the day the code ends up in TensorFlow’s XLA IR, so you get the same efficiency there, but in a form that looks and feels a lot more natural. The downside, of course, is that you still don’t have true dynamism, which is why those primitives exist, and why they are not well optimized, as described in “Jax – The Sharp Bits”. However, “most” ML algorithms don’t use very much dynamism (for example, recurrent neural networks know how many layers they have; they don’t have a while loop that iterates to a tolerance), and so “most” algorithms tend to do well in this sublanguage. In that sense, it can optimize a lot of code rather naturally.
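Tracing can likewise be sketched in plain Python. In this toy (not JAX’s internals; `Tracer`, `trace`, `run`, and `while_loop` are hypothetical names), a `Tracer` records operations into a list-based IR instead of computing them. A Python `while` on a traced value fails because the tracer has no concrete boolean, whereas the `while_loop` primitive, which only mirrors the shape of `jax.lax.while_loop(cond_fun, body_fun, init_val)`, traces its condition and body once into sub-IRs and is recorded as a single IR node. The resulting IR is the same for every input, and `run` interprets it on concrete values.

```python
class Tracer:
    """Stands in for a value during tracing: records ops instead of computing them."""

    def __init__(self, graph, name):
        self.graph = graph   # list of (output name, op, args) tuples: the IR
        self.name = name

    def _emit(self, op, *args):
        out = Tracer(self.graph, f"v{len(self.graph)}")
        self.graph.append((out.name, op,
                           [a.name if isinstance(a, Tracer) else a for a in args]))
        return out

    def __mul__(self, other):
        return self._emit("mul", self, other)

    def __lt__(self, other):
        return self._emit("lt", self, other)

    def __bool__(self):
        # Python `if`/`while` need a concrete True/False, which a tracer lacks.
        raise TypeError("cannot branch on a traced value; use a control-flow primitive")


def trace(fn):
    """Run fn once on a Tracer and return the IR as (ops, output name)."""
    graph = []
    out = fn(Tracer(graph, "x"))
    return graph, out.name


def while_loop(cond_fn, body_fn, init):
    # Hypothetical primitive: the loop is one IR node; cond and body become sub-IRs.
    return init._emit("while", trace(cond_fn), trace(body_fn), init)


def run(ir, x):
    """Interpret an IR on a concrete input value."""
    ops, out = ir
    env = {"x": x}

    def val(a):
        return env[a] if isinstance(a, str) else a

    for name, op, args in ops:
        if op == "mul":
            env[name] = val(args[0]) * val(args[1])
        elif op == "lt":
            env[name] = val(args[0]) < val(args[1])
        elif op == "while":
            cond_ir, body_ir, v = args[0], args[1], val(args[2])
            while run(cond_ir, v):
                v = run(body_ir, v)
            env[name] = v
    return val(out)


def g(x):
    y = x * x
    return while_loop(lambda v: v < 100.0, lambda v: v * 2.0, y)


def g_python(x):
    y = x * x
    while y < 100.0:        # calls Tracer.__bool__ under tracing, raising TypeError
        y = y * 2.0
    return y


ir = trace(g)
print(len(ir[0]))                   # 2: the multiply and the single "while" node
print(run(ir, 2.0), run(ir, 3.0))   # 128.0 144.0
```

Note that `trace(g)` runs the Python body of `g` only once; any side effect in `g` (printing, mutating a global) would happen at trace time rather than on each call, which is why this style wants pure functions.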

What about keeping dynamism in the AD?

This of course raises the question: is it possible to keep the full dynamism of the host language in the AD system? It is possible, but it is hard. This is what a lot of the Julia AD tools have focused on with source code transformations (along with Swift for TensorFlow). However, since source code is “for humans”, it can be a rather difficult level to work on algorithmically. Instead, these tools work on lowered IR, where the lowered representation removes a lot of the “cruft” of syntax to give a much smaller support surface. This was the core of Zygote.jl’s approach: by acting on the SSA IR it could directly support control flow like while loops without unrolling them into sets of operations (like PyTorch or TensorFlow Eager) or only supporting a sublanguage of control flow (like Jax). This is essentially done by converting while loops and other dynamic constructs into static (source code) representations that contain new lines of code for things like stacks that keep information about the forward pass (like which branch was taken), and these stacks are then accessed and used in the generated backwards pass. Thus the generated code is not dynamic (a while loop forward gives a for loop in reverse), but the generated backwards pass behaves dynamically (because it uses the stack to tell it how many times to walk the for loop). This allows AD to have a single code for all branches (unlike the tape-building forms), and thus it can optimize more like TensorFlow, but in a world where the dynamic control flow is not eliminated.
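The stack idea can be sketched by hand. Zygote works on Julia’s SSA IR, not on Python, so the following only illustrates the shape of the generated code, written in Python here: `f_forward` and `f_reverse` are hypothetical names for what a source-to-source transform might emit for `f`. The forward pass is the original loop plus pushes onto a stack; the reverse pass is a single, fixed piece of code whose `for` loop runs as many times as the stack says.

```python
# Primal function, as a user would write it.
def f(x):
    y = x * x
    while y < 100.0:
        y = y * x
    return y


# Forward pass as a source transform might emit it: the same loop, plus a stack.
def f_forward(x):
    stack = []
    y = x * x
    while y < 100.0:
        stack.append(y)       # save y before this iteration's multiply
        y = y * x
    return y, stack


# Reverse pass: fixed code, but its trip count is read from the stack at runtime.
def f_reverse(x, stack, y_bar):
    x_bar = 0.0
    for _ in range(len(stack)):   # the forward while loop becomes a reverse for loop
        y_prev = stack.pop()
        x_bar += y_bar * y_prev   # d(y_prev * x)/dx
        y_bar = y_bar * x         # d(y_prev * x)/d(y_prev)
    x_bar += y_bar * 2.0 * x      # d(x * x)/dx
    return x_bar


def f_value_and_grad(x):
    y, stack = f_forward(x)
    return y, f_reverse(x, stack, 1.0)


print(f_value_and_grad(2.0))   # (128.0, 448.0): 5 loop iterations
print(f_value_and_grad(3.0))   # (243.0, 405.0): 3 loop iterations, same code
```

The same two generated functions serve every input, whatever the iteration count, which is what lets a compiler optimize them once, unlike a tape that is rebuilt on every call.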

Well, that sounds like the best of both worlds, so why isn’t everyone using it? There are two factors involved. First, accepting that your AD will have to deal with the full dynamic nature of an entire programming language means accepting a much more difficult job. The whole purpose of the AD approaches in TensorFlow/PyTorch/Jax is for these constructs to be eliminated before the AD, so they require a much smaller surface of language support. This added complexity pretty much rules out Python, because it’s such a crazy language in terms of what it allows with dynamism (fun fact: the Jax folks at Google Brain did have a Python source code transform AD at one point, but it was scrapped essentially because of these difficulties), and so people working on these solutions flocked to languages with clear syntax that is easy for compilers to optimize, i.e. Julia and Swift. Python has most of the ML crowd, so that creates a barrier to entry.

But even then, the problem is still very hard. In Julia it was found that Zygote acts on too high-level an IR, i.e. before compiler optimizations, which requires you to do AD on unoptimized code only to delete most of the work later, and so it would be better for it to go even lower. This is the reason why the Diffractor.jl project started. But there’s a reason to act lower still, since some optimizations only occur at the LLVM level, which is why Julia developers started directly building an AD system that acts on LLVM’s IR itself, known as Enzyme (note that while this project included members of the Julia Lab like Valentin, because it acts at the LLVM level it is applicable to any LLVM-compiled language, such as C/C++ (Clang) or Rust). There is then a trade-off that occurs with source code transform methods as you go lower and lower in the IRs, which I describe in a separate post. The tl;dr there: Enzyme can act after compiler optimizations, so much of the higher-level information might be deleted (at least without completion of dialects like MLIR, which aren’t quite ready). Enzyme only sees the barest of low-level code, so it may not have the high-level linear algebra definitions needed to do all of the linear algebra simplifications, like how XLA will fuse many matrix-vector multiplications into a matrix-matrix multiplication, since some of the function calls may have been inlined and deleted. Optimizing this remaining loopy code to reach BLAS speeds is thus as hard as generating looping code that reaches BLAS speeds, and history shows this is hard but not impossible. Additionally, function calls to a nonlinear solver may have already been deleted, so optimized adjoints which outperform direct differentiation of the code, as in the case of Deep Equilibrium Models (DEQs), may end up less optimized. But that lowest level allows for very efficient differentiation of scalar code and support for mutation.

On the other hand, Diffractor uses Julia’s typed IR, so it can apply higher-level rules easily and consistently, and in theory it can do transformations similar to XLA (i.e. keeping BLAS calls intact and fusing them). But writing such analyses on a fully dynamic compute IR is difficult enough that it has not been done. Tooling around escape analysis and shape propagation is being built to try to enable such optimizations, but the fact remains that it’s a lot more work to do this on a language IR than on a sublanguage graph like XLA. In theory you could have compiler passes prove that a function is semi-static in the sense of XLA and get the same optimizations as Jax or TensorFlow, but that doesn’t happen today and it’s not easy to do. The future of Julia AD systems will likely mix the Enzyme and Diffractor approaches to tackle this issue, but the clear trade-off being made here is generality at the cost of implementation complexity.
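To make the fusion example concrete, here is a plain-Python sketch of the algebraic identity such a pass relies on: k matrix-vector products with the same matrix A equal one matrix-matrix product with the vectors stacked as columns. In a real system the unfused form would be repeated BLAS matrix-vector calls and the fused form a single matrix-matrix call; the list-based helpers `matvec` and `matmul` here are illustrative stand-ins for those kernels, not any library’s API. The rewrite is only available to a pass that can still see the individual calls and their shared A.

```python
def matvec(A, x):
    """y = A x, with A given as a list of rows."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in A]


def matmul(A, B):
    """C = A B, with A and B given as lists of rows."""
    B_cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in B_cols] for row in A]


A = [[1.0, 2.0],
     [3.0, 4.0]]
xs = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]   # several vectors hitting the same A

# Unfused: one matrix-vector product per vector.
unfused = [matvec(A, x) for x in xs]

# Fused: place the vectors side by side as the columns of X, do one
# matrix-matrix product, then read the results back out of the columns of A X.
X = [list(row) for row in zip(*xs)]          # 2 x 3, column k is xs[k]
AX = matmul(A, X)
fused = [list(col) for col in zip(*AX)]

print(unfused == fused)   # True
```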

The second factor, and probably the more damning one, is that most ML codes don’t actually use that much dynamism. Recurrent neural networks, transformers, convolutions, etc. all have simpler forms of dynamism which in some sense are quite static: the number of layers you have does not depend on the values coming out of the layers. That’s an important trade-off most people don’t always consider: why solve problems your users don’t have? Support for dynamism in ML workflows is thus mostly about convenience, not necessity. When algorithms do have dynamism, in most cases you can get away with wrapping it as an operation in the language, i.e. defining a function and defining the adjoint derivative for that function. This, for example, is how Jax supports ODEs even though adaptive ODE solvers need the calculated values in order to determine the number of steps. You cannot directly differentiate the code of an adaptive ODE solver with Jax, but if you use an ODE solver with a defined adjoint you are okay. While this does mean that some algorithms are not possible with Jax (at least without forgoing a lot of optimizations), and algorithms where differentiating the solver is fundamentally different from the adjoint definition can limit which performance/stability trade-offs can be made (see the supplemental section 8 for details, in the case of ODEs, on the stability of “discrete adjoints”), these factors seem to be rather rare in standard ML use cases, which is why most people haven’t bothered to learn a new programming language to get around these issues.
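A minimal sketch of the “wrap it and define the adjoint” pattern, using a scalar Newton solver as a stand-in for an ODE solver (the names `newton_sqrt` and `newton_sqrt_vjp` are hypothetical). The solver loops until a tolerance is met, so its iteration count depends on the data; instead of differentiating that loop, the adjoint rule comes from the implicit function theorem applied to the equation the solver satisfies. In JAX such a rule would be attached with `jax.custom_vjp`; this version is plain Python so it runs without JAX, and it checks the rule against a finite difference taken through the full solver.

```python
def newton_sqrt(a, tol=1e-12):
    """Solve x*x = a for x > 0 by Newton's method, iterating to a tolerance.

    The number of iterations depends on the data: the kind of value-dependent
    loop a tracing AD system cannot simply record as a fixed graph.
    """
    if a <= 0.0:
        raise ValueError("a must be positive")
    x = max(a, 1.0)                      # initial guess at or above sqrt(a)
    while abs(x * x - a) > tol * a:
        x = 0.5 * (x + a / x)
    return x


def newton_sqrt_vjp(x, x_bar):
    """Adjoint rule from the implicit function theorem, not from the loop.

    With F(x, a) = x*x - a = 0, dx/da = -(dF/da) / (dF/dx) = 1 / (2x),
    so the cotangent passed back to a is x_bar / (2x).
    """
    return x_bar / (2.0 * x)


a = 2.0
x = newton_sqrt(a)
grad = newton_sqrt_vjp(x, 1.0)

# Check against a central finite difference through the full solver.
h = 1e-5
fd = (newton_sqrt(a + h) - newton_sqrt(a - h)) / (2.0 * h)
print(x, grad, fd)
```

The adjoint costs one division regardless of how many Newton steps ran; that independence from the solver’s internals is what makes this pattern attractive.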

That leaves us where we are today. Are more ML algorithms of the future going to require handling more dynamic structures? Is optimizing scalar and mutating code going to be important for people using AD systems? The reason why I know this story so well is that the answer for my domain, scientific machine learning (SciML), is yes. Climate models use mutation because reallocating huge buffers would greatly affect performance. Adaptive solvers on stiff equations are a fact of life, so the simple adjoints used in PyTorch and Jax are unstable and simply give Inf as the gradients in these cases. Time will tell whether this physics-informed, expert-guided, science-guided scientific machine learning domain becomes standard, but hopefully this describes how all of the choices made here were not “better” or “worse”; instead, it’s all about domain-specific engineering trade-offs.

The post Engineering Trade-Offs in Automatic Differentiation: from TensorFlow and PyTorch to Jax and Julia appeared first on Stochastic Lifestyle.