Taku Ito, Luca Cocchi, et al.
ICML 2025
Representing control flow in machine learning (ML) compilers and intermediate representations (IRs) has been a long-standing problem, and many ML compilers choose not to support it at all. A common practice is to trace the model into a straight-line computational graph, for example by specializing on the values of if predicates or by unrolling loops. Although this strategy works well in some cases, it can lead to performance issues and long compilation times, for example when unrolling a long loop produces a very large graph. It is also problematic when the control flow of the program depends on data or on dynamic shapes, because a specialized graph is only valid for the branch or iteration count observed during tracing.
In this paper, we present the PyTorch control flow operator library, which addresses this challenge by introducing five control flow operators to PyTorch. We explain their usage, present use cases for large language models (LLMs), and demonstrate their benefits from the perspective of PyTorch 2 (PT2), such as capturing control flow in the IR with cond and reducing compilation times by rewriting loops with map, scan, associative_scan, and while_loop.
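To make the semantics of the five operators concrete, here is a minimal plain-Python sketch of what each one computes, operating on lists. It is an illustration under stated assumptions, not the PyTorch implementation: the function names (`map_` in particular, renamed to avoid shadowing Python's built-in `map`), the `model` function, and all signatures are hypothetical simplifications that only loosely follow the argument order of the PyTorch operators, which act on tensors and accept further arguments such as `dim`. Running it eagerly also does not demonstrate graph capture; it only shows each operator's input/output behavior.

```python
from typing import Any, Callable, List, Sequence, Tuple

# NOTE: hypothetical plain-Python reference semantics on lists, not the PyTorch API.


def cond(pred: bool, true_fn: Callable, false_fn: Callable, operands: Tuple) -> Any:
    # Both branches are supplied as functions; the predicate selects one at run time.
    return true_fn(*operands) if pred else false_fn(*operands)


def map_(f: Callable, xs: Sequence) -> List:
    # Apply f independently to each element along the leading dimension.
    return [f(x) for x in xs]


def scan(combine_fn: Callable, init: Any, xs: Sequence) -> Tuple[Any, List]:
    # Sequential loop with carried state: combine_fn(carry, x) -> (new_carry, y).
    carry, ys = init, []
    for x in xs:
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, ys


def associative_scan(combine_fn: Callable, xs: Sequence) -> List:
    # Inclusive prefix combine. Because combine_fn is required to be associative,
    # a real implementation may evaluate it in parallel; this sketch is sequential.
    out: List = []
    for i, x in enumerate(xs):
        out.append(x if i == 0 else combine_fn(out[-1], x))
    return out


def while_loop(cond_fn: Callable, body_fn: Callable, carried: Tuple) -> Tuple:
    # Repeat body_fn while cond_fn holds; the trip count can depend on the data.
    while cond_fn(*carried):
        carried = body_fn(*carried)
    return carried


def model(xs: List[float]) -> List[float]:
    # A data-dependent branch written with cond instead of a Python `if`.
    return cond(sum(xs) > 0,
                lambda v: [x * 2 for x in v],
                lambda v: [x - 1 for x in v],
                (xs,))


print(model([1.0, 2.0]))                                   # [2.0, 4.0]
print(model([-1.0, -2.0]))                                 # [-2.0, -3.0]
print(map_(lambda x: x * x, [1, 2, 3]))                    # [1, 4, 9]
print(scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3]))     # (6, [1, 3, 6])
print(associative_scan(lambda a, b: a + b, [1, 2, 3, 4]))  # [1, 3, 6, 10]
print(while_loop(lambda i, s: i < 5, lambda i, s: (i + 1, s + i), (0, 0)))  # (5, 10)
```

The common design point the sketch tries to convey is that every operator receives its branch or loop body as a function, so a compiler can capture that body once in the IR rather than specializing on a predicate or unrolling every iteration.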