HyperAttention: Long-context Attention in Near-Linear Time

要約

我々は、大規模言語モデル (LLM) で使用される長いコンテキストの複雑さの増大によって引き起こされる計算上の課題に対処するために、HyperAttendant という名前の近似アテンション メカニズムを紹介します。
最近の研究では、注意行列のエントリが制限されているか、行列の安定ランクが低い場合を除き、最悪のシナリオでは二次時間が必要であることが示唆されています。
(1) 正規化されたアテンション行列の最大列ノルム、(2) 大きなエントリを検出して削除した後の非正規化アテンション行列の行ノルムの比率を測定する 2 つのパラメーターを導入します。
これらのきめの細かいパラメータを使用して、問題の難しさを把握します。
以前の下限にもかかわらず、上記のパラメータが小さければ、行列に無制限のエントリがある場合や大きな安定したランクがある場合でも、線形時間サンプリング アルゴリズムを達成できます。
Hypertention は、他の高速な低レベル実装、特に FlashAttend の統合に簡単に対応できるモジュール設計を特徴としています。
経験的には、Locality Sensitive Hashing (LSH) を使用して大きなエントリを識別することで、HyperAttendant は既存の方法よりも優れたパフォーマンスを発揮し、FlashAttendant のような最先端のソリューションと比較して速度が大幅に向上します。
さまざまなコンテキスト長の長いデータセットに対する HyperAttendant の経験的なパフォーマンスを検証します。
たとえば、Hypertention を使用すると、コンテキスト長 32k で ChatGLM2 の推論時間が 50\% 高速になり、パープレキシティが 5.6 から 6.3 に増加します。
より大きなコンテキスト長 (例: 131k) では、因果マスキングを使用すると、Hypertention は単一のアテンション レイヤーで 5 倍の高速化を実現します。

要約(オリジナル)

We present an approximate attention mechanism named HyperAttention to address the computational challenges posed by the growing complexity of long contexts used in Large Language Models (LLMs). Recent work suggests that in the worst-case scenario, quadratic time is necessary unless the entries of the attention matrix are bounded or the matrix has low stable rank. We introduce two parameters which measure: (1) the max column norm in the normalized attention matrix, and (2) the ratio of row norms in the unnormalized attention matrix after detecting and removing large entries. We use these fine-grained parameters to capture the hardness of the problem. Despite previous lower bounds, we are able to achieve a linear time sampling algorithm even when the matrix has unbounded entries or a large stable rank, provided the above parameters are small. HyperAttention features a modular design that easily accommodates integration of other fast low-level implementations, particularly FlashAttention. Empirically, employing Locality Sensitive Hashing (LSH) to identify large entries, HyperAttention outperforms existing methods, giving significant speed improvements compared to state-of-the-art solutions like FlashAttention. We validate the empirical performance of HyperAttention on a variety of different long-context length datasets. For example, HyperAttention makes the inference time of ChatGLM2 50\% faster on 32k context length while perplexity increases from 5.6 to 6.3. On larger context length, e.g., 131k, with causal masking, HyperAttention offers 5-fold speedup on a single attention layer.

arxiv情報

著者 Insu Han,Rajesh Jarayam,Amin Karbasi,Vahab Mirrokni,David P. Woodruff,Amir Zandieh
発行日 2023-10-09 17:05:25+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG パーマリンク