Diagonal Batching Unlocks Parallelism in Recurrent Memory Transformers for Long Contexts

要約

トランスモデルは、二次時間と線形メモリの複雑さのために、長いコンテキスト推論と格闘しています。
再発メモリ変圧器(RMTS)は、漸近コストを線形時間と一定のメモリ使用量に削減することにより、ソリューションを提供します。
ただし、メモリの更新メカニズムは順次実行につながり、パフォーマンスボトルネックを引き起こします。
正確な再発を維持しながら、RMTのセグメント間の並列性を解き放つスケジューリングスキームである斜めのバッチを導入します。
このアプローチは、順次制約を排除し、複雑なバッチとパイプラインの技術を使用しない単一の長いコンテキスト入力でも効率的なGPU推論を可能にします。
この手法は純粋にランタイム計算の再注文であるため、既存のRMTモデルは再訓練なしでそれを採用します。
Llama-1B ARMTモデルに適用される対角線バッチは、131,072トークンシーケンスでのシーケンシャルRMT実装で標準のフルアテンションで3.3倍のスピードアップと1.8倍のスピードアップをもたらします。
連続したボトルネックを削除することにより、対角線バッチは推論コストと遅延を削減し、それによりRMTを実世界の長いコンテキストアプリケーションの実用的なソリューションとして強化します。

要約(オリジナル)

Transformer models struggle with long-context inference due to their quadratic time and linear memory complexity. Recurrent Memory Transformers (RMTs) offer a solution by reducing the asymptotic cost to linear time and constant memory usage. However, their memory update mechanism leads to sequential execution, causing a performance bottleneck. We introduce Diagonal Batching, a scheduling scheme that unlocks parallelism across segments in RMTs while preserving exact recurrence. This approach eliminates the sequential constraint, enabling efficient GPU inference even for single long-context inputs without complex batching and pipelining techniques. Because the technique is purely a run-time computation reordering, existing RMT models adopt it with no retraining. Applied to a LLaMA-1B ARMT model, Diagonal Batching yields a 3.3x speedup over standard full-attention LLaMA-1B and a 1.8x speedup over the sequential RMT implementation on 131,072-token sequences. By removing sequential bottleneck, Diagonal Batching reduces inference cost and latency, thereby strengthening RMTs as a practical solution for real-world, long-context applications.

arxiv情報

著者 Danil Sivtsov,Ivan Rodkin,Gleb Kuzmin,Yuri Kuratov,Ivan Oseledets
発行日 2025-06-05 16:43:48+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.LG パーマリンク