HyperAttention: Long-context Attention in Near-Linear Time

要約

我々は、大規模言語モデル(LLM)で使われる長いコンテキストの複雑化によってもたらされる計算上の課題に対処するために、ハイパーアテンション(HyperAttention)と名付けられた近似アテンション機構を提示する。最近の研究によれば、注意行列のエントリが有界であるか、行列の安定ランクが低い場合を除き、最悪の場合、2次関数的な時間が必要である。我々は、(1)正規化された注目行列の最大列ノルム、(2)大きなエントリを検出して除去した後の正規化されていない注目行列の行ノルムの比、を測定する2つのパラメータを導入する。これらの細かいパラメータを用いて、問題の難しさを表現する。これまでの下界にもかかわらず、我々は、上記のパラメータが小さければ、行列が無限のエントリを持つ場合や安定ランクが大きい場合でも、線形時間サンプリングアルゴリズムを達成することができる。HyperAttentionは、他の高速な低レベル実装、特にFlashAttentionを容易に統合できるモジュール設計を特徴としている。経験的に、大きなエントリを識別するためにLocality Sensitive Hashing (LSH)を採用することで、HyperAttentionは既存の手法を凌駕し、FlashAttentionのような最先端のソリューションと比較して大幅な速度向上を実現している。我々は、HyperAttentionの経験的性能を様々な異なるロングコンテクスト長のデータセットで検証した。例えば、HyperAttentionは、32kコンテキスト長において、ChatGLM2の推論時間を50%高速化する一方で、当惑度は5.6から6.3に増加する。より大きなコンテキスト長、例えば131kでは、因果的マスキングにより、HyperAttentionは1つのアテンションレイヤーで5倍のスピードアップを提供する。

要約(オリジナル)

We present an approximate attention mechanism named HyperAttention to address the computational challenges posed by the growing complexity of long contexts used in Large Language Models (LLMs). Recent work suggests that in the worst-case scenario, quadratic time is necessary unless the entries of the attention matrix are bounded or the matrix has low stable rank. We introduce two parameters which measure: (1) the max column norm in the normalized attention matrix, and (2) the ratio of row norms in the unnormalized attention matrix after detecting and removing large entries. We use these fine-grained parameters to capture the hardness of the problem. Despite previous lower bounds, we are able to achieve a linear time sampling algorithm even when the matrix has unbounded entries or a large stable rank, provided the above parameters are small. HyperAttention features a modular design that easily accommodates integration of other fast low-level implementations, particularly FlashAttention. Empirically, employing Locality Sensitive Hashing (LSH) to identify large entries, HyperAttention outperforms existing methods, giving significant speed improvements compared to state-of-the-art solutions like FlashAttention. We validate the empirical performance of HyperAttention on a variety of different long-context length datasets. For example, HyperAttention makes the inference time of ChatGLM2 50\% faster on 32k context length while perplexity increases from 5.6 to 6.3. On larger context length, e.g., 131k, with causal masking, HyperAttention offers 5-fold speedup on a single attention layer.

arxiv情報

著者 Insu Han,Rajesh Jayaram,Amin Karbasi,Vahab Mirrokni,David P. Woodruff,Amir Zandieh
発行日 2023-12-01 17:43:06+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, DeepL

カテゴリー: cs.AI, cs.LG パーマリンク