FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

要約

Transformers をより長いシーケンス長に拡張することは、ここ数年の大きな問題であり、言語モデリングと高解像度画像理解のパフォーマンスを向上させるだけでなく、コード、オーディオ、ビデオの生成における新しいアプリケーションの可能性を約束しています。
アテンション層は、ランタイムとメモリがシーケンスの長さに応じて二次関数的に増加するため、より長いシーケンスにスケーリングする際の主なボトルネックになります。
FlashAttendant は、非対称 GPU メモリ階層を活用して、近似なしで大幅なメモリ節約 (二次関数ではなく線形) と実行時間の高速化 (最適化されたベースラインと比較して 2 ~ 4 倍) を実現します。
ただし、FlashAttendant は依然として最適化行列乗算 (GEMM) 演算ほど高速ではなく、理論上の最大 FLOPs/s の 25 ~ 40\% にしか達しません。
非効率性の原因は、GPU 上の異なるスレッド ブロックとワープ間の最適ではない作業の分割によるもので、占有率の低下または不必要な共有メモリの読み取り/書き込みが発生していることがわかります。
私たちは、これらの問題に対処するために、より適切な作業分割を備えた FlashAttendant-2 を提案します。
具体的には、(1) アルゴリズムを微調整して非 matmul FLOP の数を削減します。(2) アテンションの計算を、たとえ単一のヘッドであっても、異なるスレッド ブロック間で並列化して占有率を高めます。(3) 各スレッド ブロック内で、
ワープ間で作業を分散して、共有メモリを介した通信を削減します。
これらにより、FlashAttendant と比較して約 2$\倍$ の高速化が実現し、A100 の理論上の最大 FLOP/秒の 50 ~ 73\% に達し、GEMM 操作の効率に近づきます。
GPT スタイルのモデルをトレーニングするためにエンドツーエンドで使用すると、FlashAttendant-2 は A100 GPU あたり最大 225 TFLOPs/s のトレーニング速度に達することが経験的に検証されています (72% のモデル FLOP 使用率)。

要約(オリジナル)

Scaling Transformers to longer sequence lengths has been a major problem in the last several years, promising to improve performance in language modeling and high-resolution image understanding, as well as to unlock new applications in code, audio, and video generation. The attention layer is the main bottleneck in scaling to longer sequences, as its runtime and memory increase quadratically in the sequence length. FlashAttention exploits the asymmetric GPU memory hierarchy to bring significant memory saving (linear instead of quadratic) and runtime speedup (2-4$\times$ compared to optimized baselines), with no approximation. However, FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40\% of the theoretical maximum FLOPs/s. We observe that the inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes. We propose FlashAttention-2, with better work partitioning to address these issues. In particular, we (1) tweak the algorithm to reduce the number of non-matmul FLOPs (2) parallelize the attention computation, even for a single head, across different thread blocks to increase occupancy, and (3) within each thread block, distribute the work between warps to reduce communication through shared memory. These yield around 2$\times$ speedup compared to FlashAttention, reaching 50-73\% of the theoretical maximum FLOPs/s on A100 and getting close to the efficiency of GEMM operations. We empirically validate that when used end-to-end to train GPT-style models, FlashAttention-2 reaches training speed of up to 225 TFLOPs/s per A100 GPU (72\% model FLOPs utilization).

arxiv情報

著者 Tri Dao
発行日 2023-07-17 17:50:36+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.LG パーマリンク