The above differences tell us something, but why would TensorFlow go with AOT and PyTorch with JIT? Let's try to understand this further.
In AOT, because we are compiling the whole program, we can apply very specific optimizations: fuse many kernels, combine computations, move data from one compute unit to another without leaving the device, store data in a compact layout, skip unnecessary computations (dead-code elimination), and lay the whole program out on the device so that data and code locality favor overall latency and throughput.
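As a minimal sketch of what whole-program compilation can enable, consider TensorFlow's XLA path via `tf.function(jit_compile=True)`; the function and sizes below are illustrative, and whether fusion or dead-code elimination actually fires depends on the compiler and hardware:

```python
import tensorflow as tf

# Illustrative sketch: compiling the whole function ahead of running it lets
# the compiler see every op at once, so it may fuse the elementwise ops into
# one kernel and drop the value that is never used.
@tf.function(jit_compile=True)
def step(x, w):
    unused = tf.sin(x)                  # candidate for dead-code elimination
    h = tf.matmul(x, w)
    return tf.nn.relu(h) * 2.0 + 1.0    # elementwise ops that may be fused

x = tf.random.normal([32, 64])
w = tf.random.normal([64, 16])
y = step(x, w)
```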
Now, the problems with that are:
(1) If I want to see the result of an intermediate layer, AOT may have optimized it away or modified it so that it is no longer in its original form or type.
(2) If I change one parameter, say a hyperparameter of the Adam optimizer, the whole program is compiled again even though only part of it was modified.
(3) Input data in some ML systems can come in different shapes or types, while AOT is mostly compiled for static shapes, so memory allocation is sub-optimal and the corresponding optimizations are missing.
(AOT sometimes handles dynamic shapes by compiling for the maximum shape and padding the leftover columns, which wastes resources, as sketched below.)
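A minimal sketch of that pad-to-maximum-shape workaround, assuming a graph compiled for a fixed batch size of 8 (the function name and sizes are hypothetical):

```python
import tensorflow as tf

# The graph is specialized for a static shape of [8, 128].
@tf.function(input_signature=[tf.TensorSpec(shape=[8, 128], dtype=tf.float32)])
def forward(x):
    return tf.nn.relu(tf.matmul(x, tf.ones([128, 64])))

batch = tf.random.normal([5, 128])        # only 5 rows arrive this time
padded = tf.pad(batch, [[0, 3], [0, 0]])  # pad up to the compiled batch size of 8
out = forward(padded)[:5]                 # slice the padded rows back off
```

The three wasted rows are still computed, which is exactly the resource waste mentioned above.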
But with JIT, we compile only the part of the program we are about to execute. Each layer that runs has its results available for further analysis, and they can easily be printed using Python libraries (PyTorch integrates closely with Python). JIT compiles dynamically and recompiles only the modified methods when they are executed. This is what enables "rapid prototyping".
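For example, in PyTorch's define-by-run style every intermediate result is an ordinary tensor that can be inspected with plain Python; the tiny model below is purely illustrative:

```python
import torch
import torch.nn as nn

layer1 = nn.Linear(16, 32)
layer2 = nn.Linear(32, 4)

x = torch.randn(8, 16)
h = torch.relu(layer1(x))          # intermediate result is available right here
print(h.shape, h.mean().item())    # inspect it with ordinary Python tools
y = layer2(h)
```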
JIT also takes the shapes and types of the inputs into account while compiling and thus optimizes the generated graph, so it can execute inputs of many different shapes/types in an optimized fashion. It also caches pre-compiled methods keyed on the input shapes/types it has already seen and reuses them when it encounters similar inputs again. This is "dynamic code compilation".
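A conceptual sketch of such a specialization cache (not any real framework API): compiled code is keyed on the input's shape and dtype and reused whenever a matching input shows up again:

```python
import torch

_compiled_cache = {}

def jit_call(fn, x):
    """Conceptual sketch: reuse a cached 'specialization' per (shape, dtype)."""
    key = (fn, tuple(x.shape), str(x.dtype))
    if key not in _compiled_cache:
        # A real JIT would emit shape/type-specialized code here;
        # we just record that a specialization was created.
        print(f"compiling specialization for shape={key[1]} dtype={key[2]}")
        _compiled_cache[key] = fn
    return _compiled_cache[key](x)

f = lambda t: torch.relu(t) * 2
jit_call(f, torch.randn(4, 8))    # compiles a specialization
jit_call(f, torch.randn(4, 8))    # cache hit, no recompilation
jit_call(f, torch.randn(2, 8))    # new shape -> new specialization
```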
Now, one of the important concepts that makes JIT even more efficient is "dynamic tracing". It enables dynamic control flow, optimization based on the actual data flow, and better handling of loops, and it also helps with memory allocation and efficient use of resources. JIT does this by tracing the execution path: it captures the sequence of operations performed, including function calls, variable assignments, and control-flow statements. From this trace it builds a computation graph that captures the dependencies between operations, a directed acyclic graph (DAG). Optimizations are then performed on this graph, and in many cases the result is better than what AOT achieves.
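One concrete way to see such a captured graph is `torch.jit.trace`, which runs a function once on example inputs and records the executed operations; this is a hedged sketch, and note that tracing records only the path actually taken on that run:

```python
import torch

def f(x, w):
    h = torch.matmul(x, w)          # operation recorded during tracing
    return torch.relu(h).sum()      # dependencies between ops form a DAG

x, w = torch.randn(4, 8), torch.randn(8, 2)

traced = torch.jit.trace(f, (x, w))  # run once, record the executed ops
print(traced.graph)                  # inspect the captured computation graph
```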