要約
アクセラレータ (GPU/TPU) 上の生成大規模言語モデル (LLM) を使用した自己回帰デコードは、メモリに依存することが多く、モデル パラメーターを高帯域幅メモリ (HBM) からキャッシュに転送するのにほとんどの時間が費やされます。
一方、最近の研究では、LLM が行/列の上位 $k$ 部分で動作するようにモデルを適切にトレーニングすることにより、フィードフォワード (FFN) 層で大幅なスパース性/冗長性を備えた品質を維持できることが示されています ($k \estimate
0.05$)、モデルパラメータの転送、ひいてはレイテンシを削減する方法を提案しています。
ただし、このスパース性を利用してレイテンシーを改善することは、最上位の行/列の識別がデータに依存しており、通常は完全な行列演算を使用して実行されるため、潜在的な利益が大幅に制限されるという事実によって妨げられます。
これらの問題に対処するために、HiRE (高再現率近似 Top-k 推定) を導入します。
HiRE は 2 つの新しいコンポーネントで構成されます: (i) 高い再現率で上位 $k$ の行/列を安価に予測する圧縮スキームと、その後に予測されたサブセットに限定された完全な計算、および (ii) DA-TOP-$k$:
効率的なマルチデバイス近似上位 $k$ 演算子。
10 億のパラメーター モデルで、HiRE がソフトマックス層とフィードフォワード層の両方に適用され、ほぼ一致する事前トレーニングとダウンストリーム精度を達成し、単一の TPUv5e デバイスで推論レイテンシが $1.47\times$ 高速化されることを実証します。
要約(オリジナル)
Autoregressive decoding with generative Large Language Models (LLMs) on accelerators (GPUs/TPUs) is often memory-bound where most of the time is spent on transferring model parameters from high bandwidth memory (HBM) to cache. On the other hand, recent works show that LLMs can maintain quality with significant sparsity/redundancy in the feedforward (FFN) layers by appropriately training the model to operate on a top-$k$ fraction of rows/columns (where $k \approx 0.05$), there by suggesting a way to reduce the transfer of model parameters, and hence latency. However, exploiting this sparsity for improving latency is hindered by the fact that identifying top rows/columns is data-dependent and is usually performed using full matrix operations, severely limiting potential gains. To address these issues, we introduce HiRE (High Recall Approximate Top-k Estimation). HiRE comprises of two novel components: (i) a compression scheme to cheaply predict top-$k$ rows/columns with high recall, followed by full computation restricted to the predicted subset, and (ii) DA-TOP-$k$: an efficient multi-device approximate top-$k$ operator. We demonstrate that on a one billion parameter model, HiRE applied to both the softmax as well as feedforward layers, achieves almost matching pretraining and downstream accuracy, and speeds up inference latency by $1.47\times$ on a single TPUv5e device.
arxiv情報
著者 | Yashas Samaga B L,Varun Yerram,Chong You,Srinadh Bhojanapalli,Sanjiv Kumar,Prateek Jain,Praneeth Netrapalli |
発行日 | 2024-02-14 18:04:36+00:00 |
arxivサイト | arxiv_id(pdf) |
提供元, 利用サービス
arxiv.jp, Google