MagicPIG: LSH Sampling for Efficient LLM Generation

要約

長いコンテキスト ウィンドウを持つ大規模言語モデル (LLM) が大きな注目を集めています。
ただし、再計算を避けるために保存される KV キャッシュがボトルネックになります。
注意がまばらであるという共通の洞察を活用するために、さまざまな動的スパースまたは TopK ベースの注意近似方法が提案されています。
この論文では、最初に、TopK のアテンション自体が、アテンションが常に期待ほどまばらであるとは限らないため、特定の下流タスクにおいて品質低下の影響を受けることを示します。
最も高い注意スコアを持つキーと値を選択するのではなく、理論的な保証を備えたサンプリングにより、注意出力のより適切な推定が提供されます。
LLM 生成においてサンプリングベースの近似を実用化するために、局所性敏感ハッシュ (LSH) に基づくヘテロジニアス システムである MagicPIG を提案します。
MagicPIG は、さまざまなタスクに対して高精度を維持しながら、アテンション計算の作業負荷を大幅に軽減します。
MagicPIG は LSH ハッシュ テーブルを保存し、CPU 上でアテンション計算を実行します。これにより、より長いコンテキストとより大きなバッチ サイズを高い近似精度で処理できます。
MagicPIG は、さまざまな GPU ハードウェア全体でデコード スループットを $1.9\sim3.9\times$ 向上させ、96,000 トークンのコンテキストを持つ Llama-3.1-8B-Instruct モデルの単一 RTX 4090 で 110 ミリ秒のデコード レイテンシを達成できます。
コードは \url{https://github.com/Infini-AI-Lab/MagicPIG} で入手できます。

要約(オリジナル)

Large language models (LLMs) with long context windows have gained significant attention. However, the KV cache, stored to avoid re-computation, becomes a bottleneck. Various dynamic sparse or TopK-based attention approximation methods have been proposed to leverage the common insight that attention is sparse. In this paper, we first show that TopK attention itself suffers from quality degradation in certain downstream tasks because attention is not always as sparse as expected. Rather than selecting the keys and values with the highest attention scores, sampling with theoretical guarantees can provide a better estimation for attention output. To make the sampling-based approximation practical in LLM generation, we propose MagicPIG, a heterogeneous system based on Locality Sensitive Hashing (LSH). MagicPIG significantly reduces the workload of attention computation while preserving high accuracy for diverse tasks. MagicPIG stores the LSH hash tables and runs the attention computation on the CPU, which allows it to serve longer contexts and larger batch sizes with high approximation accuracy. MagicPIG can improve decoding throughput by $1.9\sim3.9\times$ across various GPU hardware and achieve 110ms decoding latency on a single RTX 4090 for Llama-3.1-8B-Instruct model with a context of 96k tokens. The code is available at \url{https://github.com/Infini-AI-Lab/MagicPIG}.

arxiv情報

著者 Zhuoming Chen,Ranajoy Sadhukhan,Zihao Ye,Yang Zhou,Jianyu Zhang,Niklas Nolte,Yuandong Tian,Matthijs Douze,Leon Bottou,Zhihao Jia,Beidi Chen
発行日 2024-10-21 16:44:51+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.LG パーマリンク