Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues

要約

Mamba、RWKV、GLA、mLSTM、DeltaNet などの線形リカレント ニューラル ネットワーク (LRNN) は、大規模言語モデリングにおける Transformer の効率的な代替手段として登場し、シーケンス長による線形スケーリングとトレーニング効率の向上を提供します。
ただし、LRNN は状態追跡の実行に苦労し、コード評価やチェスのゲームの追跡などのタスクのパフォーマンスを損なう可能性があります。
偶数パリティは、LSTM のような非線形 RNN が効果的に処理する最も単純な状態追跡タスクですが、現在の LRNN では解決できません。
最近、サロフら。
(2024) は、Mamba のような LRNN がパリティを解決できないのは、対角状態遷移行列の値の範囲を $[0, 1]$ に制限していることに起因しており、負の値を組み込むことでこの問題を解決できることを実証しました。
この結果を非対角 LRNN に拡張します。非対角 LRNN は、最近 DeltaNet などのモデルで有望であることが示されています。
正の固有値のみを持つ状態遷移行列を持つ有限精度 LRNN ではパリティを解くことができませんが、$3$ を法としてカウントするには複素数の固有値が必要であることを証明します。
特に、LRNN の状態遷移行列が $[-1, 1]$ の範囲の固有値を持つ恒等行列からベクトル外積行列を引いた積である場合、LRNN は任意の正規言語を学習できることも証明しています。
私たちの経験的結果は、Mamba や DeltaNet などのモデルの固有値範囲を拡張して負の値を含めることで、パリティを解決できるだけでなく、状態追跡タスクのパフォーマンスも一貫して向上させることを確認しています。
さらに、言語モデリング用に拡張された固有値範囲を使用して LRNN を事前トレーニングすると、コードと数学データで期待を示しながら、同等のパフォーマンスと安定性を達成できます。
私たちの研究により、最新の LRNN の表現力が強化され、トレーニングや推論のコストを変えることなく、その適用可能性が広がります。

要約(オリジナル)

Linear Recurrent Neural Networks (LRNNs) such as Mamba, RWKV, GLA, mLSTM, and DeltaNet have emerged as efficient alternatives to Transformers in large language modeling, offering linear scaling with sequence length and improved training efficiency. However, LRNNs struggle to perform state-tracking which may impair performance in tasks such as code evaluation or tracking a chess game. Even parity, the simplest state-tracking task, which non-linear RNNs like LSTM handle effectively, cannot be solved by current LRNNs. Recently, Sarrof et al. (2024) demonstrated that the failure of LRNNs like Mamba to solve parity stems from restricting the value range of their diagonal state-transition matrices to $[0, 1]$ and that incorporating negative values can resolve this issue. We extend this result to non-diagonal LRNNs, which have recently shown promise in models such as DeltaNet. We prove that finite precision LRNNs with state-transition matrices having only positive eigenvalues cannot solve parity, while complex eigenvalues are needed to count modulo $3$. Notably, we also prove that LRNNs can learn any regular language when their state-transition matrices are products of identity minus vector outer product matrices, each with eigenvalues in the range $[-1, 1]$. Our empirical results confirm that extending the eigenvalue range of models like Mamba and DeltaNet to include negative values not only enables them to solve parity but consistently improves their performance on state-tracking tasks. Furthermore, pre-training LRNNs with an extended eigenvalue range for language modeling achieves comparable performance and stability while showing promise on code and math data. Our work enhances the expressivity of modern LRNNs, broadening their applicability without changing the cost of training or inference.

arxiv情報

著者 Riccardo Grazzi,Julien Siems,Jörg K. H. Franke,Arber Zela,Frank Hutter,Massimiliano Pontil
発行日 2024-11-19 14:35:38+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.FL, cs.LG パーマリンク