FlashDecoding++: Faster Large Language Model Inference on GPUs

要約

大規模言語モデル(LLM)は様々な領域でますます重要性を増している。しかし、LLMの推論を高速化するためには、以下の課題がまだ解決されていない。ソフトマックス演算は、各部分ソフトマックス結果間の同期更新演算を必要とするため、LLMにおける注目計算のオーバーヘッドが20%程度発生する。(2)フラットGEMMの計算の過小利用。LLM推論でGEMMを実行する行列の形状は平坦であるため、従来の設計ではゼロをパディングした後の計算が十分に利用されず、50%以上の性能低下を招く。(3) 静的データフローによる性能低下。LLMのカーネル性能は、様々な入力データの特徴やハードウェア構成などに依存する。単一の静的なデータフローは、LLM推論において異なる形状のGEMMに対して50.25%の性能損失につながる可能性がある。 我々は、主流のLLMとハードウェアバックエンドをサポートする高速LLM推論エンジンFlashDecoding++を発表する。上記の課題に取り組むために、FlashDecoding++は次のような独創的な提案を行っている。FlashDecoding++は、同期を避けるために、異なる部分的なソフトマックス計算に統一された最大値技術を導入する。(2) ダブルバッファリングによるフラットGEMM最適化。FlashDecoding++は、様々な形状のフラットGEMMが様々なボトルネックに直面していることを指摘しています。そこで、二重バッファリングなどの手法を導入する。(3)ハードウェア資源適応による発見的データフロー。FlashDecoding++は、入力のダイナミクスを考慮し、異なるハードウェアリソースを用いてデータフローを発見的に最適化する。FlashDecoding++の最適化の多様性により、FlashDecoding++はHugging Faceの実装と比較して、NVIDIAとAMDの両方のGPUで最大4.86倍と2.18倍の高速化を達成することができます。また、FlashDecoding++は、主流のLLM上で最新のLLM推論エンジンと比較して平均1.37倍の高速化を達成しています。

要約(オリジナル)

As the Large Language Model (LLM) becomes increasingly important in various domains. However, the following challenges still remain unsolved in accelerating LLM inference: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update operation among each partial softmax result, leading to ~20% overheads for the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The shape of matrices performing GEMM in LLM inference is flat, leading to under-utilized computation and >50% performance loss after padding zeros in previous designs. (3) Performance loss due to static dataflow. Kernel performance in LLM depends on varied input data features, hardware configurations, etc. A single and static dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in LLM inference. We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware back-ends. To tackle the above challenges, FlashDecoding++ creatively proposes: (1) Asynchronized softmax with unified max value. FlashDecoding++ introduces a unified max value technique for different partial softmax computations to avoid synchronization. (2) Flat GEMM optimization with double buffering. FlashDecoding++ points out that flat GEMMs with different shapes face varied bottlenecks. Then, techniques like double buffering are introduced. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes dataflow using different hardware resource considering input dynamics. Due to the versatility of optimizations in FlashDecoding++, FlashDecoding++ can achieve up to 4.86x and 2.18x speedup on both NVIDIA and AMD GPUs compared to Hugging Face implementations. FlashDecoding++ also achieves an average speedup of 1.37x compared to state-of-the-art LLM inference engines on mainstream LLMs.

arxiv情報

著者 Ke Hong,Guohao Dai,Jiaming Xu,Qiuli Mao,Xiuhong Li,Jun Liu,Kangdi Chen,Hanyu Dong,Yu Wang
発行日 2023-11-03 14:59:06+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, DeepL

カテゴリー: cs.CL, cs.LG パーマリンク