Drama: Mamba-Enabled Model-Based Reinforcement Learning Is Sample and Parameter Efficient

要約

モデルベースの強化学習 (RL) は、ほとんどのモデルフリー RL アルゴリズムを悩ませるデータの非効率性に対する解決策を提供します。
ただし、堅牢な世界モデルを学習するには、多くの場合、計算とトレーニングにコストがかかる、複雑で奥深いアーキテクチャが必要になります。
ワールド モデルの中で、ダイナミクス モデルは正確な予測に特に重要であり、それぞれに独自の課題を抱えたさまざまなダイナミクス モデル アーキテクチャが検討されてきました。
現在、リカレント ニューラル ネットワーク (RNN) ベースの世界モデルは、勾配の消失や長期的な依存関係を効果的に把握することが難しいなどの問題に直面しています。
対照的に、トランスフォーマーの使用には、メモリと計算の複雑さの両方が $O(n^2)$ ($n$ がシーケンス長を表す) としてスケールされるセルフ アテンション メカニズムのよく知られた問題があります。
これらの課題に対処するために、私たちは、長期的な依存関係を効果的にキャプチャし、より長いトレーニング シーケンスの効率的な使用を容易にしながら、$O(n)$ のメモリと計算の複雑さを達成する、特に Mamba に基づいた状態空間モデル (SSM) ベースのワールド モデルを提案します。

また、トレーニングの初期段階で不正確なワールド モデルによって引き起こされる準最適性を軽減する新しいサンプリング手法を導入し、前述の手法と組み合わせて、他の最先端のモデルベースの RL アルゴリズムに匹敵する正規化スコアを達成します。
700 万のトレーニング可能なパラメータの世界モデルのみを使用します。
このモデルはアクセス可能で、既製のラップトップでトレーニングできます。
私たちのコードは https://github.com/realwenlongwang/drama.git で入手できます。

要約(オリジナル)

Model-based reinforcement learning (RL) offers a solution to the data inefficiency that plagues most model-free RL algorithms. However, learning a robust world model often demands complex and deep architectures, which are expensive to compute and train. Within the world model, dynamics models are particularly crucial for accurate predictions, and various dynamics-model architectures have been explored, each with its own set of challenges. Currently, recurrent neural network (RNN) based world models face issues such as vanishing gradients and difficulty in capturing long-term dependencies effectively. In contrast, use of transformers suffers from the well-known issues of self-attention mechanisms, where both memory and computational complexity scale as $O(n^2)$, with $n$ representing the sequence length. To address these challenges we propose a state space model (SSM) based world model, specifically based on Mamba, that achieves $O(n)$ memory and computational complexity while effectively capturing long-term dependencies and facilitating the use of longer training sequences efficiently. We also introduce a novel sampling method to mitigate the suboptimality caused by an incorrect world model in the early stages of training, combining it with the aforementioned technique to achieve a normalised score comparable to other state-of-the-art model-based RL algorithms using only a 7 million trainable parameter world model. This model is accessible and can be trained on an off-the-shelf laptop. Our code is available at https://github.com/realwenlongwang/drama.git.

arxiv情報

著者 Wenlong Wang,Ivana Dusparic,Yucheng Shi,Ke Zhang,Vinny Cahill
発行日 2024-10-11 15:10:40+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG, cs.RO パーマリンク