DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training

要約

FlashAttendant (Dao、2023) は、単一 GPU でのトランスフォーマーベースの大規模言語モデル (LLM) のトレーニングにおいて、二次ピーク メモリ使用量を線形に効果的に削減します。
このペーパーでは、ロングコンテキスト LLM トレーニング用に最適化された分散メモリ効率の高いアテンション メカニズムである DISTFLASHATTN を紹介します。
私たちは、トークンレベルのワークロードバランシング、重複するキーと値の通信、再実体化を意識した勾配チェックポイントアルゴリズムという 3 つの主要な技術を提案します。
Llama-7B および 32K から 512K までの配列長を持つバリアントで DISTFLASHATTN を評価します。
DISTFLASHATTN は、Ring Self-Attend と比較して 8 倍長いシーケンス、4.45 ~ 5.64 倍の高速化、FlashAttendant を使用する Megatron-LM と比較して 2 ~ 8 倍長いシーケンス、1.24 ~ 2.01 倍の高速化を実現します。
最近の Ring Attendant や DeepSpeed-Ulysses と比較して、1.67 倍および 1.26 ~ 1.88 倍の高速化を実現します。
コードは https://github.com/RulinShao/LightSeq で入手できます。

要約(オリジナル)

FlashAttention (Dao, 2023) effectively reduces the quadratic peak memory usage to linear in training transformer-based large language models (LLMs) on a single GPU. In this paper, we introduce DISTFLASHATTN, a distributed memory-efficient attention mechanism optimized for long-context LLMs training. We propose three key techniques: token-level workload balancing, overlapping key-value communication, and a rematerialization-aware gradient checkpointing algorithm. We evaluate DISTFLASHATTN on Llama-7B and variants with sequence lengths from 32K to 512K. DISTFLASHATTN achieves 8x longer sequences, 4.45 – 5.64x speedup compared to Ring Self-Attention, 2 – 8x longer sequences, 1.24 – 2.01x speedup compared to Megatron-LM with FlashAttention. It achieves 1.67x and 1.26 – 1.88x speedup compared to recent Ring Attention and DeepSpeed-Ulysses. Code is available at https://github.com/RulinShao/LightSeq.

arxiv情報

著者 Dacheng Li,Rulin Shao,Anze Xie,Eric P. Xing,Xuezhe Ma,Ion Stoica,Joseph E. Gonzalez,Hao Zhang
発行日 2024-03-31 21:11:08+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.DC, cs.LG パーマリンク