A 4D Hybrid Algorithm to Scale Parallel Training to Thousands of GPUs

要約

分散システム上で最先端のニューラル ネットワークをトレーニングする場合、多額の通信コストが重大なボトルネックになります。
この論文では、Agarwal の行列乗算アルゴリズムに触発された、新しい 4 次元 (4D) 並列化アプローチである AxoNN を紹介します。深層学習におけるテンソル計算の並列化に使用されます。AxoNN は、通信オーバーヘッドを最小限に抑えるために 2 つの重要な戦略を採用しています。
まず、コストのかかる集合操作 (reduce-scatter、all-gather、all-reduce) と計算をオーバーラップさせることで通信を最適化します。
200 億パラメータのトランスフォーマー モデルを使用した実験では、これらの最適化により 53\% 近くの改善が得られることが実証されました。
次に、ユーザーが 4D アルゴリズムによって定義された広大な検索空間内で通信を最小限に抑える構成を特定するのを支援する分析モデルを紹介します。
このモデルは、特定のトレーニング ワークロードの調整プロセスを簡素化することで、実践者に力を与えます。
Perlmutter の 1,024 GPU で 800 億のパラメーター モデルをトレーニングした場合、AxoNN は最先端のフレームワークである Megatron-LM を 26% 大幅に上回りました。
さらに、理論上のピーク FLOP/秒の 57% を達成します。

要約(オリジナル)

Large communication costs are a critical bottleneck in training state-of-the-art neural networks on distributed systems. This paper introduces AxoNN, a novel four-dimensional (4D) parallelization approach, inspired by Agarwal’s algorithm for matrix multiplication, for parallelizing tensor computations in deep learning, AxoNN employs two key strategies to minimize communication overhead. First, we optimize communication by overlapping expensive collective operations (reduce-scatter, all-gather, all-reduce) with computations. Our experiments with a 20-billion parameter transformer model demonstrate that these optimizations deliver nearly 53\% improvement. Second, we present an analytical model to assist users in identifying communication-minimizing configurations within the vast search space defined by our 4D algorithm. This model empowers practitioners by simplifying the tuning process for their specific training workloads. When training an 80-billion parameter model on 1024 GPUs of Perlmutter, AxoNN surpasses Megatron-LM, a state-of-the-art framework, by a significant 26%. Additionally, it achieves 57% of the theoretical peak FLOP/s.

arxiv情報

著者 Siddharth Singh,Prajwal Singhania,Aditya K. Ranjan,Zack Sating,Abhinav Bhatele
発行日 2024-03-27 17:47:56+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.DC, cs.LG, cs.PF パーマリンク