Optimizing Distributed ML Communication with Fused Computation-Collective Operations

Bibliographic details
Published in: SC24: International Conference for High Performance Computing, Networking, Storage and Analysis, pp. 1–17
Authors: Punniyamurthy, Kishore; Hamidouche, Khaled; Beckmann, Bradford M.
Format: Conference proceedings
Language: English
Published: IEEE, 17 November 2024
Description
Abstract: Machine learning models are distributed across multiple nodes using numerous parallelism strategies. The resulting collective communication is often on the critical path because there are no independent, coarse-grained computation kernels available to execute alongside it. In this work, we propose fusing computation with its subsequent collective communication and leveraging GPUs' massive parallelism, along with GPU-initiated communication, to overlap communication and computation. Specifically, threadblocks/workgroups (WGs) communicate their results to remote GPUs immediately after completing their computation, while other WGs within the same kernel continue computing. We developed three prototype fused operators (embedding + All-to-All, GEMV + AllReduce, and GEMM + All-to-All) to address the communication overheads in DLRM, Transformer, and MoE model architectures. We expose the fused kernels as new PyTorch operators and also extend the Triton framework to demonstrate their practicality. Our evaluations show that our approach effectively overlaps communication with computation, reducing the combined execution time by 12% to 31% across all three operators.
DOI:10.1109/SC41406.2024.00094
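The fused-operator idea in the abstract (each workgroup sends its result as soon as it is computed, instead of waiting for the whole kernel to finish) can be illustrated with a small, sequential Python simulation of GEMV + AllReduce. This is a hypothetical sketch, not the authors' implementation: the function names (`gemv_rows`, `fused_gemv_allreduce`, `bulk_gemv_allreduce`), the column-sharded data layout, and modeling a remote write as an add into each GPU's output buffer are all assumptions made for illustration. Each (GPU, row tile) pair stands in for one workgroup. The simulation runs serially, so it demonstrates ordering and correctness, not speedup.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]
Event = Tuple[str, int, int]  # (kind, gpu, tile_start)


def gemv_rows(a: Matrix, x: Vector, start: int, end: int) -> Vector:
    """Compute rows [start, end) of a @ x: the work of one simulated workgroup."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(a[r], x)) for r in range(start, end)]


def fused_gemv_allreduce(a_shards: List[Matrix], x_shards: List[Vector], tile_rows: int):
    """Sequential simulation of a fused GEMV + AllReduce (hypothetical sketch).

    Simulated GPU g holds a column shard a_shards[g] and the matching slice
    x_shards[g], so the full result is y = sum over g of a_shards[g] @ x_shards[g].
    """
    num_gpus = len(a_shards)
    rows = len(a_shards[0])
    outputs = [[0.0] * rows for _ in range(num_gpus)]
    events: List[Event] = []
    for start in range(0, rows, tile_rows):
        end = min(start + tile_rows, rows)
        for g in range(num_gpus):
            partial = gemv_rows(a_shards[g], x_shards[g], start, end)
            events.append(("compute", g, start))
            # Assumed model of a GPU-initiated remote write: add this tile into
            # every GPU's output buffer right away, before later tiles are computed.
            for peer in range(num_gpus):
                for i, value in enumerate(partial):
                    outputs[peer][start + i] += value
            events.append(("send", g, start))
    return outputs, events


def bulk_gemv_allreduce(a_shards: List[Matrix], x_shards: List[Vector]) -> List[Vector]:
    """Baseline: finish every local GEMV first, then sum-reduce to all GPUs."""
    local = [gemv_rows(a, x, 0, len(a)) for a, x in zip(a_shards, x_shards)]
    reduced = [sum(col) for col in zip(*local)]
    return [list(reduced) for _ in a_shards]


# Example: a 4x4 matrix split by columns across two simulated GPUs.
a = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
x = [1, 1, 1, 1]
a_shards = [[row[:2] for row in a], [row[2:] for row in a]]
x_shards = [x[:2], x[2:]]
fused, events = fused_gemv_allreduce(a_shards, x_shards, tile_rows=2)
print(fused[0])  # [10.0, 26.0, 42.0, 58.0]
```

Both versions produce the same reduced vector on every simulated GPU; the difference is that in the fused version the first tile is already "sent" before the second tile is computed, which is the ordering that lets communication overlap with computation on real hardware.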