要約
トランスフォーマーベースの言語モデルに対するリカレント ニューラル ネットワーク (RNN) の重要な利点の 1 つは、シーケンスの長さに関する線形計算の複雑さです。これにより、推論中の長いシーケンスの処理が大幅に高速になります。
しかし、ほとんどの公的に利用可能な RNN (Mamba や RWKV など) は 10,000 トークン未満のシーケンスでトレーニングされており、より長いコンテキストでの有効性は今のところほとんど満足のいくものではありません。
このペーパーでは、RNN の長いコンテキストを処理できない原因を調査し、重要な緩和策を提案します。
最先端の RNN を長いコンテキストに適用する場合の 2 つの実際的な懸念事項を検討します。(1) トレーニング長よりも長い入力を外挿できないこと、および (2) メモリ容量の上限です。
最初の懸念事項に対処するには、まず、トレーニング中に発生しないシーケンス長で重大なパフォーマンス低下を引き起こす現象である *状態崩壊* (SC) を調査します。
制御された実験では、トレーニングの長さに対して反復状態が過剰にパラメータ化されているため、これは過剰適合であると考えられます。
2 番目の懸念事項については、言語モデリングとパスキー取得における反復状態容量を経験的に推定するために、長い文書で一連の Mamba-2 モデルをトレーニングします。
次に、Mamba-2 の長さの一般化性を改善するために 3 つの SC 緩和方法が提案され、モデルが SC なしで 100 万を超えるトークンを処理できるようになります。
また、パスキー取得における反復状態容量が状態サイズに指数関数的にスケールすることもわかり、256K コンテキスト長でほぼ完璧なパスキー取得精度で Mamba-2 370M を経験的にトレーニングしました。
これは、RNN ベースのロングコンテキスト モデリングの有望な将来を示唆しています。
要約(オリジナル)
One essential advantage of recurrent neural networks (RNNs) over transformer-based language models is their linear computational complexity concerning the sequence length, which makes them much faster in handling long sequences during inference. However, most publicly available RNNs (e.g., Mamba and RWKV) are trained on sequences with less than 10K tokens, and their effectiveness in longer contexts remains largely unsatisfying so far. In this paper, we study the cause of the inability to process long context for RNNs and suggest critical mitigations. We examine two practical concerns when applying state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to inputs longer than the training length and (2) the upper bound of memory capacity. Addressing the first concern, we first investigate *state collapse* (SC), a phenomenon that causes severe performance degradation on sequence lengths not encountered during training. With controlled experiments, we attribute this to overfitting due to the recurrent state being overparameterized for the training length. For the second concern, we train a series of Mamba-2 models on long documents to empirically estimate the recurrent state capacity in language modeling and passkey retrieval. Then, three SC mitigation methods are proposed to improve Mamba-2’s length generalizability, allowing the model to process more than 1M tokens without SC. We also find that the recurrent state capacity in passkey retrieval scales exponentially to the state size, and we empirically train a Mamba-2 370M with near-perfect passkey retrieval accuracy on 256K context length. This suggests a promising future for RNN-based long-context modeling.
arxiv情報
著者 | Yingfa Chen,Xinrong Zhang,Shengding Hu,Xu Han,Zhiyuan Liu,Maosong Sun |
発行日 | 2024-10-09 17:54:28+00:00 |
arxivサイト | arxiv_id(pdf) |
提供元, 利用サービス
arxiv.jp, Google