Squeezed Attention: Accelerating Long Context Length LLM Inference

要約

新しい大規模言語モデル (LLM) アプリケーションでは、ドキュメント分析やコード生成などの複雑な下流タスクを実行するために長い入力プロンプトが必要です。
このような長いコンテキスト長のアプリケーションでは、推論コストがシーケンスの長さに比例して増加するため、入力プロンプトの長さが推論効率の点で大きな課題となります。
ただし、これらのアプリケーションの多くでは、プロンプト内のコンテキストの多くがさまざまなユーザー入力にわたって固定されているため、ユーザー入力を受信したときにオフライン最適化を実行して、ユーザー入力を迅速に処理する機会が得られます。
この研究では、入力プロンプトの大部分が固定されている LLM アプリケーションを高速化するメカニズムとして Squeezed Attend を提案します。
まず、オフラインで K 平均法クラスタリングを利用して、意味論的な類似性に基づいて固定コンテキストのキーをグループ化し、各クラスターを単一の重心値で表します。
推論中に、ユーザー入力からのクエリ トークンを重心と比較して、固定コンテキストからのどのキーが意味的に関連しており、推論中にロードする必要があるかを予測します。
次に、固定コンテキストからのこれらの重要なキーのみを使用して正確なアテンションを計算し、それによって帯域幅と計算コストを削減します。
また、重要なキーを識別するために階層的な重心検索を使用するように方法を拡張しました。これにより、コンテキストの長さに関して、注意の複雑さを線形から対数に減らすことができます。
セントロイド比較用に最適化された Triton カーネルと、重要なキーを含むスパース FlashAttend を実装し、長いコンテキスト推論のプレフィル フェーズと生成フェーズの両方で 4 倍以上の高速化を達成しました。
さらに、LongBench を含むさまざまなロングコンテキスト ベンチマークでこの手法を広範囲に評価しました。このベンチマークでは、精度を損なうことなく KV キャッシュ バジェットを 3 倍削減し、さまざまなモデルで精度ギャップが 0.5 ポイント未満で最大 8 倍の削減を達成しました。

要約(オリジナル)

Emerging Large Language Model (LLM) applications require long input prompts to perform complex downstream tasks like document analysis and code generation. For these long context length applications, the length of the input prompt poses a significant challenge in terms of inference efficiency since the inference costs increase linearly with sequence length. However, for many of these applications, much of the context in the prompt is fixed across different user inputs, thereby providing the opportunity to perform offline optimizations to process user inputs quickly, as they are received. In this work, we propose Squeezed Attention as a mechanism to accelerate LLM applications where a large portion of the input prompt is fixed. We first leverage K-means clustering offline to group the keys for the fixed context based on semantic similarity and represent each cluster with a single centroid value. During inference, we compare query tokens from the user input with the centroids to predict which of the keys from the fixed context are semantically relevant and need to be loaded during inference. We then compute exact attention using only these important keys from the fixed context, thereby reducing bandwidth and computational costs. We also extend our method to use a hierarchical centroid lookup to identify important keys, which can reduce the complexity of attention from linear to logarithmic with respect to the context length. We implement optimized Triton kernels for centroid comparison and sparse FlashAttention with important keys, achieving more than 4x speedups during both the prefill and generation phases for long-context inference. Furthermore, we have extensively evaluated our method on various long-context benchmarks including LongBench, where it achieves a 3x reduction in KV cache budget without accuracy loss and up to an 8x reduction with <0.5 point accuracy gap for various models.

arxiv情報

著者 Coleman Hooper,Sehoon Kim,Hiva Mohammadzadeh,Monishwaran Maheswaran,June Paik,Michael W. Mahoney,Kurt Keutzer,Amir Gholami
発行日 2024-11-14 18:54:19+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL パーマリンク