LoLCATs: On Low-Rank Linearizing of Large Language Models

要約

最近の研究では、大規模言語モデル (LLM) を線形化できることが示されています。これは、一般的な Transformer ベースの LLM の 2 次アテンションを、線形アテンションなどの二次アテンションと交換することで、高価な事前トレーニング コストを回避できることを示しています。
ただし、LLM の線形化はモデルの品質を大幅に低下させることが多く、依然として数十億トークンにわたるトレーニングが必要であり、依然として小規模な 13 億から 70 億 LLM に制限されています。
そこで、我々は、メモリとコンピューティングを桁違いに削減して LLM 線形化の品質を向上させるシンプルな 2 ステップの方法である、アテンション転送による低ランク線形変換 (LoLCAT) を提案します。
これらの手順は 2 つの調査結果に基づいています。
まず、出力 MSE 損失 (「注意の伝達」) に対応するソフトマックスの対応物と一致するように線形アテンションをトレーニングするだけで、LLM のソフトマックス アテンションを非常に近似した線形アテンションに置き換えることができます。
これにより、近似誤差を調整し、低ランク適応 (LoRA) だけで LLM 品質を回復することが可能になります。
LoLCAT は、線形化の品質、トレーニング効率、およびスケーラビリティを大幅に向上させます。
線形化の品質ギャップを大幅に削減し、Llama 3 8B および Mistral 7B v0.1 から最先端の二次二次 LLM を生成し、5 ショット MMLU で 20 ポイント以上の改善につながりました。
さらに、LoLCATs は、過去のメソッドのモデル パラメーターの 0.2% とトレーニング トークンの 0.4% のみを使用してこれを実行します。
最後に、LoLCAT を適用して、最初の線形化された 70B および 405B LLM (以前の作業より 50 倍大きい) を作成します。
同じコンピューティング バジェットの下で以前のアプローチと比較した場合、LoLCAT は線形化の品質を大幅に向上させ、線形化された LLM と元の Llama 3.1 70B および 405B LLM の間のギャップを 5 ショット MMLU で 77.8% および 78.1% 縮めます。

要約(オリジナル)

Recent works show we can linearize large language models (LLMs) — swapping the quadratic attentions of popular Transformer-based LLMs with subquadratic analogs, such as linear attention — avoiding the expensive pretraining costs. However, linearizing LLMs often significantly degrades model quality, still requires training over billions of tokens, and remains limited to smaller 1.3B to 7B LLMs. We thus propose Low-rank Linear Conversion via Attention Transfer (LoLCATs), a simple two-step method that improves LLM linearizing quality with orders of magnitudes less memory and compute. We base these steps on two findings. First, we can replace an LLM’s softmax attentions with closely-approximating linear attentions, simply by training the linear attentions to match their softmax counterparts with an output MSE loss (‘attention transfer’). Then, this enables adjusting for approximation errors and recovering LLM quality simply with low-rank adaptation (LoRA). LoLCATs significantly improves linearizing quality, training efficiency, and scalability. We significantly reduce the linearizing quality gap and produce state-of-the-art subquadratic LLMs from Llama 3 8B and Mistral 7B v0.1, leading to 20+ points of improvement on 5-shot MMLU. Furthermore, LoLCATs does so with only 0.2% of past methods’ model parameters and 0.4% of their training tokens. Finally, we apply LoLCATs to create the first linearized 70B and 405B LLMs (50x larger than prior work). When compared with prior approaches under the same compute budgets, LoLCATs significantly improves linearizing quality, closing the gap between linearized and original Llama 3.1 70B and 405B LLMs by 77.8% and 78.1% on 5-shot MMLU.

arxiv情報

著者 Michael Zhang,Simran Arora,Rahul Chalamala,Alan Wu,Benjamin Spector,Aaryan Singhal,Krithik Ramesh,Christopher Ré
発行日 2024-10-25 17:59:04+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.CL, cs.LG, stat.ML パーマリンク