要約
大規模な言語モデル(LLMS)は、少数のショットプロンプト、マルチステップの推論、投機的デコードなどを含む、トークンの共有プレフィックスを使用してツリー構造で複数の生成コールを処理する複雑なタスクにますます採用されています。ただし、ツリーベースの用途向けの既存の推論システムは、注意計算中のクエリーとKVキャッシュの不適切な分配により効果的ではありません。
これにより、2つの主要な問題が発生します。(1)共有プレフィックスのKVキャッシュのメモリアクセスの欠如(IO)再利用、および(2)荷重のバランスの低下。
これらの課題に対処するために、deft(フラッシュツリーアテンションでデコードする)を提案します。これは、接頭辞が認識され、負荷バランスの取れたKVキャッシュパーティションを備えたハードウェア効率の高い注意アルゴリズムです。
DEFTは、注意計算で共有プレフィックスのKVキャッシュを繰り返しロードすることを避ける方法であるKVガイドグループを介した注意計算中に、KVキャッシュの読み取り/書き込み操作の数を減らします。
さらに、平坦化されたツリーKV分割を提案します。これは、計算冗長性がほとんどなく、パーティション全体にKVキャッシュの分布を保証し、注意計算中のGPU使用を強化するメカニズムを提案します。
注意計算中に部分的な結果のために73-99%kVキャッシュIOとほぼ100%IOを減らすことにより、DEFTは、最先端の注意アルゴリズムと比較して、3つの実用的なツリーベースのワークロードにわたってエンドツーエンド/注意レイテンシで最大2.23/3.59xスピードアップを達成します。
私たちのコードは、https://github.com/lins-lab/deftで入手できます。
要約(オリジナル)
Large language models (LLMs) are increasingly employed for complex tasks that process multiple generation calls in a tree structure with shared prefixes of tokens, including few-shot prompting, multi-step reasoning, speculative decoding, etc. However, existing inference systems for tree-based applications are inefficient due to improper partitioning of queries and KV cache during attention calculation. This leads to two main issues: (1) a lack of memory access (IO) reuse for KV cache of shared prefixes, and (2) poor load balancing.As a result, there is redundant KV cache IO between GPU global memory and shared memory, along with low GPU utilization. To address these challenges, we propose DeFT(Decoding with Flash Tree-Attention), a hardware-efficient attention algorithm with prefix-aware and load-balanced KV cache partitions. DeFT reduces the number of read/write operations of KV cache during attention calculation through KV-Guided Grouping, a method that avoids repeatedly loading KV cache of shared prefixes in attention computation. Additionally, we propose Flattened Tree KV Splitting, a mechanism that ensures even distribution of the KV cache across partitions with little computation redundancy, enhancing GPU utilization during attention computations. By reducing 73-99% KV cache IO and nearly 100% IO for partial results during attention calculation, DeFT achieves up to 2.23/3.59x speedup in the end-to-end/attention latency across three practical tree-based workloads compared to state-of-the-art attention algorithms. Our code is available at https://github.com/LINs-lab/DeFT.
arxiv情報
著者 | Jinwei Yao,Kaiqi Chen,Kexun Zhang,Jiaxuan You,Binhang Yuan,Zeke Wang,Tao Lin |
発行日 | 2025-03-07 17:47:42+00:00 |
arxivサイト | arxiv_id(pdf) |
提供元, 利用サービス
arxiv.jp, Google