FlashDecoding++: Faster Large Language Model Inference on GPUs

要約

大規模言語モデル (LLM) はさまざまなドメインでますます重要になっています。
ただし、LLM 推論の高速化においては、次の課題がまだ解決されていません。 (1) 同期された部分的なソフトマックス更新。
ソフトマックス操作では、各部分的なソフトマックス結果間の同期された更新操作が必要であり、LLM でのアテンション計算で最大 20% のオーバーヘッドが発生します。
(2) フラット GEMM の計算が十分に活用されていない。
LLM 推論で GEMM を実行する行列の形状は平坦であるため、以前の設計でゼロを埋め込んだ後は計算が十分に活用されず、50% を超えるパフォーマンスが低下します。
(3) 静的データフローによるパフォーマンスの損失。
LLM のカーネルのパフォーマンスは、さまざまな入力データの特徴、ハードウェア構成などに依存します。LLM 推論では、単一の静的なデータフローにより、さまざまな形状の GEMM のパフォーマンスが 50.25% 低下する可能性があります。
主流の LLM とハードウェア バックエンドをサポートする高速 LLM 推論エンジンである FlashDecoding++ を紹介します。
上記の課題に取り組むために、FlashDecoding++ は次のことを独創的に提案します。 (1) 最大値が統一された非同期ソフトマックス。
FlashDecoding++ では、同期を回避するために、さまざまな部分的なソフトマックス計算に統一された最大値手法が導入されています。
(2) ダブルバッファリングによるフラット GEMM 最適化。
FlashDecoding++ は、さまざまな形状のフラット GEMM がさまざまなボトルネックに直面していることを指摘しています。
次に、ダブルバッファリングなどの技術が導入されます。
(3) ハードウェア リソースの適応を伴うヒューリスティック データフロー。
FlashDecoding++ は、入力ダイナミクスを考慮して、さまざまなハードウェア リソースを使用してデータフローをヒューリスティックに最適化します。
FlashDecoding++ の最適化の多様性により、FlashDecoding++ は、Hugging Face 実装と比較して、NVIDIA と AMD GPU の両方で最大 4.86 倍および 2.18 倍の高速化を達成できます。
FlashDecoding++ は、主流の LLM 上の最先端の LLM 推論エンジンと比較して、平均 1.37 倍の高速化も達成します。

要約(オリジナル)

As the Large Language Model (LLM) becomes increasingly important in various domains. However, the following challenges still remain unsolved in accelerating LLM inference: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update operation among each partial softmax result, leading to ~20% overheads for the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices performing GEMM in LLM inference is flat, leading to under-utilized computation and >50% performance loss after padding zeros in previous designs. (3) Performance loss due to static dataflow. Kernel performance in LLM depends on varied input data features, hardware configurations, etc. A single and static dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in LLM inference. We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++ creatively proposes: (1) Asynchronized softmax with unified max value. FlashDecoding++ introduces a unified max value technique for different partial softmax computations to avoid synchronization. (2) Flat GEMM optimization with double buffering. FlashDecoding++ points out that flat GEMMs with different shapes face varied bottlenecks. Then, techniques like double buffering are introduced. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes dataflow using different hardware resource considering input dynamics. Due to the versatility of optimizations in FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on both NVIDIA and AMD GPUs compared to Hugging Face implementations. FlashDecoding++ also achieves an average speedup of 1.37x compared to state-of-the-art LLM inference engines on mainstream LLMs.

arxiv情報

著者 Ke Hong,Guohao Dai,Jiaming Xu,Qiuli Mao,Xiuhong Li,Jun Liu,Kangdi Chen,Yuhan Dong,Yu Wang
発行日 2024-01-05 12:41:13+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.LG パーマリンク