Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

要約

大規模言語モデル (LLM) の推論プロセスは、自己回帰デコード プロセスに並列性がないために制限されることが多く、その結果、ほとんどの操作がアクセラレータのメモリ帯域幅によって制限されます。
この問題に対処するために、投機的デコードなどの方法が提案されていますが、その実装は、別のドラフト モデルの取得と維持に伴う課題によって妨げられています。
この論文では、追加のデコード ヘッドを追加して後続の複数のトークンを並行して予測することで LLM 推論を強化する効率的な方法である Medusa を紹介します。
Medusa はツリーベースのアテンション メカニズムを使用して複数の継続候補を構築し、各デコード ステップでそれらを同時に検証します。
Medusa は、並列処理を活用することで、必要なデコード ステップの数を大幅に削減しながら、シングル ステップのレイテンシーの観点から最小限のオーバーヘッドのみを導入します。
さまざまなユースケースのニーズを満たすために、Medusa の 2 つのレベルの微調整手順を示します。 Medusa-1: Medusa は凍結されたバックボーン LLM 上で直接微調整され、ロスレス推論の高速化が可能になります。
Medusa-2: Medusa はバックボーン LLM とともに微調整されており、Medusa ヘッドの予測精度の向上と高速化が可能ですが、バックボーン モデルの機能を維持する特別なトレーニング レシピが必要です。
さらに、トレーニング データが利用できない状況に対処するための自己蒸留や、生成品質を維持しながら受け入れ率を高めるための典型的な受け入れスキームなど、Medusa の有用性を改善または拡張するいくつかの拡張機能を提案します。
さまざまなサイズのモデルとトレーニング手順で Medusa を評価します。
私たちの実験では、Medusa-1 が生成品質を損なうことなく 2.2 倍を超える高速化を達成できるのに対し、Medusa-2 はさらに 2.3 ~ 3.6 倍まで高速化することを示しています。

要約(オリジナル)

The inference process in Large Language Models (LLMs) is often limited due to the absence of parallelism in the auto-regressive decoding process, resulting in most operations being restricted by the memory bandwidth of accelerators. While methods such as speculative decoding have been suggested to address this issue, their implementation is impeded by the challenges associated with acquiring and maintaining a separate draft model. In this paper, we present Medusa, an efficient method that augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, Medusa constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By leveraging parallel processing, Medusa introduces only minimal overhead in terms of single-step latency while substantially reducing the number of decoding steps required. We present two levels of fine-tuning procedures for Medusa to meet the needs of different use cases: Medusa-1: Medusa is directly fine-tuned on top of a frozen backbone LLM, enabling lossless inference acceleration. Medusa-2: Medusa is fine-tuned together with the backbone LLM, enabling better prediction accuracy of Medusa heads and higher speedup but needing a special training recipe that preserves the backbone model’s capabilities. Moreover, we propose several extensions that improve or expand the utility of Medusa, including a self-distillation to handle situations where no training data is available and a typical acceptance scheme to boost the acceptance rate while maintaining generation quality. We evaluate Medusa on models of various sizes and training procedures. Our experiments demonstrate that Medusa-1 can achieve over 2.2x speedup without compromising generation quality, while Medusa-2 further improves the speedup to 2.3-3.6x.

arxiv情報

著者 Tianle Cai,Yuhong Li,Zhengyang Geng,Hongwu Peng,Jason D. Lee,Deming Chen,Tri Dao
発行日 2024-01-19 15:48:40+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.LG パーマリンク