Improving Token-Based World Models with Parallel Observation Prediction

要約

離散シンボルのシーケンスに適用した場合の Transformers の成功をきっかけに、トークンベースのワールド モデル (TBWM) がサンプル効率の高い方法として最近提案されました。
TBWM では、ワールド モデルはエージェント エクスペリエンスを言語に似たトークンのシーケンスとして消費し、各観察がサブシーケンスを構成します。
ただし、想像中は、次の観測をトークンごとに順次生成するため、深刻なボトルネックが発生し、トレーニング時間が長くなり、GPU 使用率が低下し、表現が制限されます。
このボトルネックを解決するために、私たちは新しい並列観測予測 (POP) メカニズムを考案しました。
POP は、強化学習設定に合わせた新しい順方向モードで Retentive Network (RetNet) を拡張します。
当社では、REM (保持環境モデル) という名前の新しい TBWM エージェントに POP を組み込み、以前の TBWM と比較して 15.4 倍高速な想像力を示しています。
REM は、12 時間未満のトレーニングで、Atari 100K ベンチマークの 26 ゲーム中 12 ゲームで超人的なパフォーマンスを達成しました。
コードは \url{https://github.com/leor-c/REM} で入手できます。

要約(オリジナル)

Motivated by the success of Transformers when applied to sequences of discrete symbols, token-based world models (TBWMs) were recently proposed as sample-efficient methods. In TBWMs, the world model consumes agent experience as a language-like sequence of tokens, where each observation constitutes a sub-sequence. However, during imagination, the sequential token-by-token generation of next observations results in a severe bottleneck, leading to long training times, poor GPU utilization, and limited representations. To resolve this bottleneck, we devise a novel Parallel Observation Prediction (POP) mechanism. POP augments a Retentive Network (RetNet) with a novel forward mode tailored to our reinforcement learning setting. We incorporate POP in a novel TBWM agent named REM (Retentive Environment Model), showcasing a 15.4x faster imagination compared to prior TBWMs. REM attains superhuman performance on 12 out of 26 games of the Atari 100K benchmark, while training in less than 12 hours. Our code is available at \url{https://github.com/leor-c/REM}.

arxiv情報

著者 Lior Cohen,Kaixin Wang,Bingyi Kang,Shie Mannor
発行日 2024-02-13 15:38:11+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG パーマリンク