Trained Transformers Learn Linear Models In-Context

要約

トランスフォーマーなどの注意ベースのニューラル ネットワークは、コンテキスト内学習 (ICL) を示す驚くべき能力を実証しています。目に見えないタスクからトークンの短いプロンプト シーケンスが与えられると、パラメーターを指定せずに関連するトークンごとの予測と次のトークンの予測を定式化できます。
更新情報。
一連のラベル付きトレーニング データとラベルなしテスト データをプロンプトとして埋め込むことで、トランスフォーマーが教師あり学習アルゴリズムのように動作できるようになります。
実際、最近の研究では、線形回帰問題のランダムなインスタンスに対して変換器アーキテクチャをトレーニングすると、これらのモデルの予測が通常の最小二乗の予測を模倣することが示されました。
この現象の根底にあるメカニズムを理解するために、線形回帰タスクの勾配フローによってトレーニングされた単一の線形自己注意層を備えた変圧器における ICL のダイナミクスを調査します。
非凸性にもかかわらず、適切なランダム初期化を伴う勾配流が目的関数の大域的最小値を見つけることを示します。
このグローバル最小値では、新しい予測タスクからのラベル付きサンプルのテスト プロンプトが与えられると、トランスフォーマーは、テスト プロンプトの分布に対して最良の線形予測子と競合する予測誤差を達成します。
さらに、さまざまな分布シフトに対する学習済みトランスフォーマーの堅牢性を特徴付け、多くのシフトは許容されるが、プロンプトの共変量分布のシフトは許容されないことを示します。
これを動機として、共変量分布がプロンプト間で異なる可能性がある一般化された ICL 設定を検討します。
この設定では、勾配流は大域的最小値を見つけることに成功しますが、訓練された変換器は、緩やかな共変量シフトの下では依然として脆弱であることを示します。
我々は、共変量シフトの下でより堅牢であることを示す大規模な非線形変換器アーキテクチャに関する実験でこの発見を補完します。

要約(オリジナル)

Attention-based neural networks such as transformers have demonstrated a remarkable ability to exhibit in-context learning (ICL): Given a short prompt sequence of tokens from an unseen task, they can formulate relevant per-token and next-token predictions without any parameter updates. By embedding a sequence of labeled training data and unlabeled test data as a prompt, this allows for transformers to behave like supervised learning algorithms. Indeed, recent work has shown that when training transformer architectures over random instances of linear regression problems, these models’ predictions mimic those of ordinary least squares. Towards understanding the mechanisms underlying this phenomenon, we investigate the dynamics of ICL in transformers with a single linear self-attention layer trained by gradient flow on linear regression tasks. We show that despite non-convexity, gradient flow with a suitable random initialization finds a global minimum of the objective function. At this global minimum, when given a test prompt of labeled examples from a new prediction task, the transformer achieves prediction error competitive with the best linear predictor over the test prompt distribution. We additionally characterize the robustness of the trained transformer to a variety of distribution shifts and show that although a number of shifts are tolerated, shifts in the covariate distribution of the prompts are not. Motivated by this, we consider a generalized ICL setting where the covariate distributions can vary across prompts. We show that although gradient flow succeeds at finding a global minimum in this setting, the trained transformer is still brittle under mild covariate shifts. We complement this finding with experiments on large, nonlinear transformer architectures which we show are more robust under covariate shifts.

arxiv情報

著者 Ruiqi Zhang,Spencer Frei,Peter L. Bartlett
発行日 2023-08-11 02:04:46+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.CL, cs.LG, stat.ML パーマリンク