Fourier Head: Helping Large Language Models Learn Complex Probability Distributions

要約

大規模な言語モデルの品質が向上するにつれて、それらを使用して非言語トークンをモデル化することへの関心が高まっています。
たとえば、Decision Transformer は、デコーダ専用 LLM を使用して、Atari エージェントの個別のアクション空間にわたる分布をモデル化し、エージェントの意思決定をシーケンス モデリング問題として再キャストします。
ただし、LLM を非言語ドメインに適応させる場合、離散ビン上のソフトマックスがトークンの連続構造と、高品質のトークン生成に必要な潜在的に複雑な分布を捕捉できるかどうかは不明のままです。
フーリエ級数を使用して構築されたニューラル ネットワーク層を導入します。これは、出力をより連続的な構造にしたい場合に、任意の線形層を簡単に置き換えることができます。
当社は、合成データセットだけでなく、大規模な意思決定や時系列予測タスクに対して広範な分析を実行します。
また、この層が高周波ノイズを無視しながらデータから信号をより適切に学習できるという理論的証拠も提供します。
私たちの結果はすべて、基礎となるデータ分布が自然な連続構造を持つシナリオにおける、私たちが提案するフーリエ ヘッドの有効性を裏付けています。
たとえば、フーリエ ヘッドは、Atari Seaquest ゲームでの Decision Transformer エージェントの収益を 46% 向上させ、トレーニング中には見ら​​れなかった 20 のベンチマーク全体で、最先端の時系列基盤モデルの予測パフォーマンスを 3.5% 向上させました。

要約(オリジナル)

As the quality of large language models has improved, there has been increased interest in using them to model non-linguistic tokens. For example, the Decision Transformer recasts agentic decision making as a sequence modeling problem, using a decoder-only LLM to model the distribution over the discrete action space for an Atari agent. However, when adapting LLMs to non-linguistic domains, it remains unclear if softmax over discrete bins captures the continuous structure of the tokens and the potentially complex distributions needed for high quality token generation. We introduce a neural network layer, constructed using Fourier series, which we can easily substitute for any linear layer if we want the outputs to have a more continuous structure. We perform extensive analysis on synthetic datasets, as well as on large-scale decision making and time series forecasting tasks. We also provide theoretical evidence that this layer can better learn signal from data while ignoring high-frequency noise. All of our results support the effectiveness of our proposed Fourier head in scenarios where the underlying data distribution has a natural continuous structure. For example, the Fourier head improves a Decision Transformer agent’s returns by 46% on the Atari Seaquest game, and increases a state-of-the-art times series foundation model’s forecasting performance by 3.5% across 20 benchmarks unseen during training.

arxiv情報

著者 Nate Gillman,Daksh Aggarwal,Michael Freeman,Saurabh Singh,Chen Sun
発行日 2024-10-29 17:27:58+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.CL, cs.LG, stat.ML パーマリンク