Efficiently Dispatching Flash Attention For Partially Filled Attention Masks

要約

トランスフォーマーはさまざまなアプリケーションで広く使用されており、その多くはスパースまたは部分的に満たされたアテンション行列を生成します。
例には、アテンションの 2 次複雑さを軽減するように設計されたアテンション マスク、シーケンス パッキング技術、MEDUSA での高速検証のためのツリー マスキングなどの最近のイノベーションが含まれます。
これらの行列には固有の疎性があるにも関わらず、最先端のアルゴリズムであるフラッシュ アテンションは、行列が密であるかのように 2 次の複雑さで処理します。
このペーパーでは、マスクを認識させることでフラッシュ アテンションを強化する非常に効率的な変更であるバイナリ ブロック マスキングを紹介します。
さらに、2 つの最適化を提案します。1 つは連続した非ゼロ パターンを持つマスクに合わせたもの、もう 1 つは非常にまばらなマスクに合わせたものです。
現実世界のシナリオから派生したアテンション マスクに関する実験では、ランタイムが最大 9 倍向上することが実証されました。
この実装は、さらなる研究と応用を促進するために一般に公開されます。

要約(オリジナル)

Transformers are widely used across various applications, many of which yield sparse or partially filled attention matrices. Examples include attention masks designed to reduce the quadratic complexity of attention, sequence packing techniques, and recent innovations like tree masking for fast validation in MEDUSA. Despite the inherent sparsity in these matrices, the state-of-the-art algorithm Flash Attention still processes them with quadratic complexity as though they were dense. In this paper, we introduce Binary Block Masking, a highly efficient modification that enhances Flash Attention by making it mask-aware. We further propose two optimizations: one tailored for masks with contiguous non-zero patterns and another for extremely sparse masks. Our experiments on attention masks derived from real-world scenarios demonstrate up to a 9x runtime improvement. The implementation will be publicly released to foster further research and application.

arxiv情報

著者 Agniv Sharma,Jonas Geiping
発行日 2024-09-24 12:56:13+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.CL, cs.LG パーマリンク