Optimizing Distributed ML Communication with Fused Computation-Collective Operations
| Published in: | SC24: International Conference for High Performance Computing, Networking, Storage and Analysis, pp. 1-17 |
|---|---|
| Main authors: | , , |
| Format: | Conference paper |
| Language: | English |
| Published: | IEEE, 17.11.2024 |
| Subjects: | |
| Online access: | Full text |
| Summary: | Machine learning models are distributed across multiple nodes using numerous parallelism strategies. The resulting collective communication is often on the critical path due to a lack of independent coarse-grain computation kernels available to execute. In this work, we propose fusing computation with its subsequent collective communication and leveraging GPUs' massive parallelism, along with GPU-initiated communication, to overlap communication and computation. Specifically, thread blocks/workgroups (WGs) immediately communicate their results to remote GPUs after completing their computation, while other WGs within the same kernel continue computing. We developed three prototype fused operators (embedding+All-to-All, GEMV+AllReduce, and GEMM+All-to-All) to address the communication overheads in the DLRM, Transformer, and MoE model architectures. We expose the fused kernels as new PyTorch operators and extend the Triton framework to demonstrate their practicality. Our evaluations show that our approach effectively overlaps communication with computation, reducing the combined execution time by 12%-31% across all three operators. |
| DOI: | 10.1109/SC41406.2024.00094 |
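
The core mechanism described in the abstract, thread blocks pushing their finished results to remote GPUs while sibling blocks of the same kernel keep computing, can be illustrated with a small sketch. The following CUDA/NVSHMEM example is an assumption-laden illustration, not the authors' implementation: the paper exposes fused PyTorch/Triton operators (whose stack and GPU-initiated communication library may differ), and the kernel name `fused_gemv_push`, the tile size, and the ring-peer exchange below are invented for the example.

```cuda
// Illustrative sketch only (not the paper's code): a "compute a tile, then
// immediately push it" kernel using NVSHMEM as a stand-in for GPU-initiated
// communication. Each thread block computes TILE rows of y = A * x and ships
// that tile to a ring neighbor as soon as it is done, so communication of
// early tiles overlaps with computation of later tiles in the same kernel.
#include <cstdio>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <nvshmemx.h>

constexpr int TILE = 128;    // output rows produced per thread block
constexpr int COLS = 1024;   // reduction (inner) dimension

__global__ void fused_gemv_push(const float *A, const float *x,
                                float *y_local,   // per-GPU scratch for results
                                float *y_sym,     // NVSHMEM symmetric destination
                                int rows, int peer) {
    int row = blockIdx.x * TILE + threadIdx.x;    // one thread per output row
    if (row < rows) {
        float acc = 0.f;
        for (int c = 0; c < COLS; ++c) acc += A[row * COLS + c] * x[c];
        y_local[row] = acc;                       // stage the result locally
    }
    __syncthreads();                              // this block's tile is complete

    int start = blockIdx.x * TILE;
    int n     = min(TILE, rows - start);
    // Block-scoped put: push the finished tile to the peer GPU right away,
    // while other blocks of this kernel are still computing their tiles.
    nvshmemx_float_put_block(y_sym + start, y_local + start, n, peer);
}

int main() {
    nvshmem_init();
    cudaSetDevice(nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE));
    int me = nvshmem_my_pe(), npes = nvshmem_n_pes();
    int peer = (me + 1) % npes;                   // a ring exchange stands in for the collective
    int rows = 4096;

    float *A, *x, *y_local;
    cudaMalloc(&A, sizeof(float) * rows * COLS);
    cudaMalloc(&x, sizeof(float) * COLS);
    cudaMalloc(&y_local, sizeof(float) * rows);
    cudaMemset(A, 0, sizeof(float) * rows * COLS);  // demo data; real inputs omitted
    cudaMemset(x, 0, sizeof(float) * COLS);
    float *y_sym = (float *)nvshmem_malloc(sizeof(float) * rows);

    fused_gemv_push<<<(rows + TILE - 1) / TILE, TILE>>>(A, x, y_local, y_sym, rows, peer);
    cudaDeviceSynchronize();
    nvshmem_barrier_all();                        // remote tiles are now visible

    printf("PE %d: pushed GEMV tiles to PE %d\n", me, peer);
    nvshmem_free(y_sym);
    cudaFree(A); cudaFree(x); cudaFree(y_local);
    nvshmem_finalize();
    return 0;
}
```

The point of the block-scoped put is that the network starts draining results as soon as the first tiles finish, rather than waiting for the whole kernel and a separate collective call; the fused operators in the paper apply this producer-push idea to embedding+All-to-All, GEMV+AllReduce, and GEMM+All-to-All.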