The Mamba in the Llama: Distilling and Accelerating Hybrid Models

要約

Mamba のような線形 RNN アーキテクチャは、有利な展開特性を備えながら、言語モデリングにおいて Transformer モデルと競合できます。
大規模な Transformer モデルのトレーニングに重点を置いているため、これらの事前トレーニング済みモデルをデプロイメント用に変換するという課題を検討します。
学術的な GPU リソースを使用してアテンション層からの線形射影の重みを再利用することで、大規模な Transformer を線形 RNN に蒸留することが実現可能であることを示します。
結果として得られるハイブリッド モデルは、アテンション レイヤーの 4 分の 1 を組み込んでおり、チャット ベンチマークではオリジナルの Transformer に匹敵するパフォーマンスを達成し、チャット ベンチマークと一般ベンチマークの両方で、数兆のトークンを使用してゼロからトレーニングされたオープンソース ハイブリッド Mamba モデルを上回ります。
さらに、Mamba モデルとハイブリッド モデルの推論速度を高速化するハードウェア対応の投機的デコード アルゴリズムを導入します。
全体として、限られた計算リソースで、元のアテンション層の多くを削除し、結果として得られるモデルからより効率的に生成できる方法を示します。
Llama3-8B-Instruct から抽出された最高パフォーマンスのモデルは、AlpacaEval 2 で GPT-4 に対して長さ制御された勝率 29.61、MT-Bench で 7.35 を達成し、最高の命令調整された線形 RNN モデルを上回っています。

要約(オリジナル)

Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.

arxiv情報

著者 Junxiong Wang,Daniele Paliotta,Avner May,Alexander M. Rush,Tri Dao
発行日 2024-08-27 17:56:11+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG パーマリンク