Optimized Multi-Token Joint Decoding with Auxiliary Model for LLM Inference

要約

大規模言語モデル (LLM) は、さまざまなタスクにわたって目覚ましい成功を収めていますが、その推論プロセスは、各デコード ステップでの単一トークンの生成によるかなりの時間とエネルギーの需要によって妨げられています。
投機的デコードなどの以前の方法では、ステップごとに複数のトークンを生成することでこれらの非効率性を軽減していますが、各トークンは引き続き単一トークンの配布によって生成されるため、効率性を向上させることなく速度を向上させることができます。
対照的に、私たちの取り組みは、推論速度の向上と出力の効率の向上を同時に実現します。
私たちは、反復ごとに結合分布から複数のトークンを生成するマルチトークン結合デコード (MTJD) を検討します。これにより、理論的には複雑さが軽減され、タスクのパフォーマンスが向上します。
ただし、MTJD は複数のトークンの共同配布からのサンプリングのコストが高いという問題があります。
投機的デコードからインスピレーションを得て、MTJD を高速化するために設計された新しいフレームワークであるマルチトークン支援デコード (MTAD) を紹介します。
MTAD は、より小さな補助モデルを利用して、より大きなモデルの結合分布を近似し、この近似の精度を保証するだけでなく、従来の投機的デコードよりもデコード効率を向上させる検証メカニズムを組み込んでいます。
理論的には、MTAD が誤差を制限して正確な MTJD に近似することを示します。
さまざまなタスクにわたって 13B から 70B のパラメーターにわたる Llama-2 および OPT モデルを使用した実証評価により、標準の単一トークン サンプリングと比較して、MTAD が複雑さを 21.2% 削減し、ダウンストリームのパフォーマンスを向上させることが明らかになりました。
さらに、MTAD は従来の投機的復号方法と比べて 1.42 倍の高速化を実現し、消費エネルギーは 1.54 倍少なくなります。
これらの結果は、マルチトークンの共同デコードを効果的かつ効率的に行い、LLM のより持続可能で高性能な展開を促進する MTAD の能力を強調しています。

要約(オリジナル)

Large language models (LLMs) have achieved remarkable success across diverse tasks, yet their inference processes are hindered by substantial time and energy demands due to single-token generation at each decoding step. While previous methods such as speculative decoding mitigate these inefficiencies by producing multiple tokens per step, each token is still generated by its single-token distribution, thereby enhancing speed without improving effectiveness. In contrast, our work simultaneously enhances inference speed and improves the output effectiveness. We consider multi-token joint decoding (MTJD), which generates multiple tokens from their joint distribution at each iteration, theoretically reducing perplexity and enhancing task performance. However, MTJD suffers from the high cost of sampling from the joint distribution of multiple tokens. Inspired by speculative decoding, we introduce multi-token assisted decoding (MTAD), a novel framework designed to accelerate MTJD. MTAD leverages a smaller auxiliary model to approximate the joint distribution of a larger model, incorporating a verification mechanism that not only ensures the accuracy of this approximation, but also improves the decoding efficiency over conventional speculative decoding. Theoretically, we demonstrate that MTAD closely approximates exact MTJD with bounded error. Empirical evaluations using Llama-2 and OPT models ranging from 13B to 70B parameters across various tasks reveal that MTAD reduces perplexity by 21.2% and improves downstream performance compared to standard single-token sampling. Furthermore, MTAD achieves a 1.42x speed-up and consumes 1.54x less energy than conventional speculative decoding methods. These results highlight MTAD’s ability to make multi-token joint decoding both effective and efficient, promoting more sustainable and high-performance deployment of LLMs.

arxiv情報

著者 Zongyue Qin,Ziniu Hu,Zifan He,Neha Prakriya,Jason Cong,Yizhou Sun
発行日 2024-10-02 16:14:09+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.LG パーマリンク