Optimizing Distributed ML Communication with Fused Computation-Collective Operations

Bibliographic details
Published in: SC24: International Conference for High Performance Computing, Networking, Storage and Analysis, pp. 1-17
Main authors: Punniyamurthy, Kishore; Hamidouche, Khaled; Beckmann, Bradford M.
Format: Conference paper
Language: English
Published: IEEE, 17 November 2024
Description
Abstract: Machine learning models are distributed across multiple nodes using numerous parallelism strategies. The resulting collective communication is often on the critical path because there are no independent coarse-grain computation kernels available to execute in the meantime. In this work, we propose fusing computation with its subsequent collective communication and leveraging GPUs' massive parallelism, along with GPU-initiated communication, to overlap communication and computation. Specifically, threadblocks/workgroups (WGs) immediately communicate their results to remote GPUs after completing their computation, while other WGs within the same kernel continue computing. We developed three prototype fused operators (embedding+All-to-All, GEMV+AllReduce, and GEMM+All-to-All) to address the communication overheads in the DLRM, Transformer, and MoE model architectures. We expose the fused kernels as new PyTorch operators and extend the Triton framework to demonstrate their practicality. Our evaluations show that our approach effectively overlaps communication with computation, reducing the combined execution time by 12% to 31% across all three operators.
DOI:10.1109/SC41406.2024.00094
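To make the fusion idea in the abstract concrete, here is a minimal CPU-only sketch of GEMV+AllReduce in which each workgroup's output block is sent off as soon as it is computed, instead of waiting for the whole kernel to finish. It assumes tensor-parallel sharding along the reduction dimension K, so each simulated GPU holds a partial output vector that must be summed. GPUs and workgroups are modeled as loop iterations, a `queue.Queue` plus one background thread stands in for GPU-initiated communication, and the all-reduce is a naive "add every block into every GPU's buffer." All names (`gemv_rows`, `fused_gemv_allreduce`) are hypothetical; this is not the authors' GPU or Triton implementation.

```python
import queue
import threading

# Hypothetical illustration only: simulated GPUs/workgroups, not the paper's kernels.


def gemv_rows(A, x, row_start, row_end):
    """Compute rows [row_start, row_end) of the local partial product A @ x."""
    return [sum(a * b for a, b in zip(A[r], x)) for r in range(row_start, row_end)]


def fused_gemv_allreduce(A_parts, x_parts, num_wgs):
    """A_parts[g] is GPU g's M x K_g shard of A; x_parts[g] is its K_g slice of x.

    Returns one fully reduced output vector per simulated GPU.
    """
    num_gpus = len(A_parts)
    M = len(A_parts[0])
    rows_per_wg = (M + num_wgs - 1) // num_wgs
    # Receive buffer per simulated GPU; ends up holding sum_g A_g @ x_g.
    recv = [[0] * M for _ in range(num_gpus)]
    # Stand-in (assumption) for GPU-initiated puts to remote GPUs.
    outbox = queue.Queue()

    def communicator():
        # Applies each partial block to every GPU's buffer as soon as it arrives,
        # concurrently with the main thread computing further blocks. A single
        # consumer thread serializes the accumulations, so no locks are needed.
        while True:
            item = outbox.get()
            if item is None:  # sentinel: all workgroups have finished
                break
            start, block = item
            for buf in recv:
                for i, v in enumerate(block):
                    buf[start + i] += v

    comm = threading.Thread(target=communicator)
    comm.start()
    for g in range(num_gpus):
        for wg in range(num_wgs):
            start = wg * rows_per_wg
            end = min(start + rows_per_wg, M)
            if start >= end:
                continue  # more workgroups than row blocks
            block = gemv_rows(A_parts[g], x_parts[g], start, end)
            # Communicate this WG's result immediately: no kernel-wide barrier
            # between computation and the collective.
            outbox.put((start, block))
    outbox.put(None)
    comm.join()
    return recv


# A = [[1, 2], [3, 4], [5, 6]] split by columns across two simulated GPUs, x = [1, 1].
print(fused_gemv_allreduce([[[1], [3], [5]], [[2], [4], [6]]], [[1], [1]], num_wgs=2))
```

Running it prints `[[3, 7, 11], [3, 7, 11]]`: every simulated GPU ends up with the full product `A @ x`. The point of the sketch is the ordering, not the performance: communication for early row blocks is issued while later blocks are still being computed. A real fused kernel would rely on GPU-initiated communication and a bandwidth-efficient reduction scheme rather than a Python queue and a naive broadcast-and-add.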