Tiled Flash Linear Attention: More Efficient Linear RNN and xLSTM Kernels

要約

ゲーティングを伴う線形RNNは最近、言語モデリングのトランスと比較して競争力のあるパフォーマンスを実証しました。
シーケンスの長さの線形計算スケーリングは、変圧器よりも理論的なランタイムの利点を提供しますが、これらの利点が実際には非常に効率的なフラッシュ注意カーネルに依存しているため、最適化されたカスタムカーネルが必要です。
線形RNNのチャンクワイズパラレル配合を活用すると、フラッシュ線形注意(FLA)は、入力シーケンスのチャンクを並列化することにより、線形RNNカーネルがフラッシュの注意よりも高速であることを示しています。
ただし、FLAのチャンクサイズは限られているため、多くの中間状態をGPU​​メモリで実現する必要があります。
これにより、算術強度が低くなり、特に長いコンテキストの事前トレーニングでは、メモリ消費量が高くなり、IOコストが発生します。
この作業では、各チャンク内に追加のレベルのシーケンス並列化を導入することにより、任意の大きなチャンクサイズを可能にする、線形RNNの新しいカーネルアルゴリズムであるタイル張りのフラッシュリニアメント(TFLA)を提示します。
まず、MLSTMを使用してXLSTMにTFLAを適用します。
第二に、Sigmoid入力ゲートを備えたMLSTMバリアントを提案し、同じ言語モデリングパフォーマンスでさらに速いカーネルランタイムの計算を削減します。
スピードベンチマークでは、TFLAに基づいた新しいMLSTMカーネルが、高度に最適化されたフラッシュの注意、線形注意、およびマンバカーネルを上回り、効率的な長いコンテキストシーケンスモデリングプリミティブの新しい最新技術を設定することを示します。

要約(オリジナル)

Linear RNNs with gating recently demonstrated competitive performance compared to Transformers in language modeling. Although their linear compute scaling in sequence length offers theoretical runtime advantages over Transformers, realizing these benefits in practice requires optimized custom kernels, as Transformers rely on the highly efficient Flash Attention kernels. Leveraging the chunkwise-parallel formulation of linear RNNs, Flash Linear Attention (FLA) shows that linear RNN kernels are faster than Flash Attention, by parallelizing over chunks of the input sequence. However, since the chunk size of FLA is limited, many intermediate states must be materialized in GPU memory. This leads to low arithmetic intensity and causes high memory consumption and IO cost, especially for long-context pre-training. In this work, we present Tiled Flash Linear Attention (TFLA), a novel kernel algorithm for linear RNNs, that enables arbitrary large chunk sizes by introducing an additional level of sequence parallelization within each chunk. First, we apply TFLA to the xLSTM with matrix memory, the mLSTM. Second, we propose an mLSTM variant with sigmoid input gate and reduced computation for even faster kernel runtimes at equal language modeling performance. In our speed benchmarks, we show that our new mLSTM kernels based on TFLA outperform highly optimized Flash Attention, Linear Attention and Mamba kernels, setting a new state of the art for efficient long-context sequence modeling primitives.

arxiv情報

著者 Maximilian Beck,Korbinian Pöppel,Phillip Lippe,Sepp Hochreiter
発行日 2025-03-18 16:09:47+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG パーマリンク