Provably learning a multi-head attention layer

要約

マルチヘッド アテンション レイヤーは、トランス アーキテクチャを従来のフィードフォワード モデルとは区別する重要なコンポーネントの 1 つです。
シーケンス長 $k$、注目行列 $\mathbf{\Theta}_1,\ldots,\mathbf{\Theta}_m\in\mathbb{R}^{d\times d}$、射影行列 $\ が与えられるとします。
mathbf{W}_1,\ldots,\mathbf{W}_m\in\mathbb{R}^{d\times d}$、対応するマルチヘッド アテンション レイヤー $F: \mathbb{R}^{k\
回 d}\to \mathbb{R}^{k\times d}$ は、長さ $k$ の $d$ 次元トークンのシーケンス $\mathbf{X}\in\mathbb{R}^{k\times を変換します
d}$ via $F(\mathbf{X}) \triangleq \sum^m_{i=1} \mathrm{softmax}(\mathbf{X}\mathbf{\Theta}_i\mathbf{X}^\top
)\mathbf{X}\mathbf{W}_i$。
この研究では、ランダムな例からマルチヘッド アテンション層を証明可能に学習する研究を開始し、この問題に対する最初の自明でない上限と下限を与えます。 – $\{\mathbf{W}_i, \mathbf{\Theta を提供
}_i\}$ が特定の非縮退条件を満たしている場合、$\{ から均一に抽出されたランダムなラベル付きサンプルを与えられた場合に $F$ を小さな誤差まで学習する $(dk)^{O(m^3)}$ 時間アルゴリズムを与えます。
\pm 1\}^{k\times d}$。
– 最悪の場合、$m$ への指数関数的な依存が避けられないことを示す計算の下限を証明します。
私たちは、大規模な言語モデルにおけるトークンの離散的な性質を模倣するためにブール値 $\mathbf{X}$ に焦点を当てていますが、私たちの技術は当然、標準的な連続設定にも拡張されています。
ガウス。
私たちのアルゴリズムは、例を使用して未知のパラメーターを含む凸体を彫刻することに重点を置いており、主にガウス分布の代数および回転不変特性を利用する、フィードフォワード ネットワークを学習するための既存の証明可能なアルゴリズムとは大きく異なります。
対照的に、私たちの分析は主に入力分布とその「スライス」のさまざまな上下の裾の境界に依存しているため、より柔軟です。

要約(オリジナル)

The multi-head attention layer is one of the key components of the transformer architecture that sets it apart from traditional feed-forward models. Given a sequence length $k$, attention matrices $\mathbf{\Theta}_1,\ldots,\mathbf{\Theta}_m\in\mathbb{R}^{d\times d}$, and projection matrices $\mathbf{W}_1,\ldots,\mathbf{W}_m\in\mathbb{R}^{d\times d}$, the corresponding multi-head attention layer $F: \mathbb{R}^{k\times d}\to \mathbb{R}^{k\times d}$ transforms length-$k$ sequences of $d$-dimensional tokens $\mathbf{X}\in\mathbb{R}^{k\times d}$ via $F(\mathbf{X}) \triangleq \sum^m_{i=1} \mathrm{softmax}(\mathbf{X}\mathbf{\Theta}_i\mathbf{X}^\top)\mathbf{X}\mathbf{W}_i$. In this work, we initiate the study of provably learning a multi-head attention layer from random examples and give the first nontrivial upper and lower bounds for this problem: – Provided $\{\mathbf{W}_i, \mathbf{\Theta}_i\}$ satisfy certain non-degeneracy conditions, we give a $(dk)^{O(m^3)}$-time algorithm that learns $F$ to small error given random labeled examples drawn uniformly from $\{\pm 1\}^{k\times d}$. – We prove computational lower bounds showing that in the worst case, exponential dependence on $m$ is unavoidable. We focus on Boolean $\mathbf{X}$ to mimic the discrete nature of tokens in large language models, though our techniques naturally extend to standard continuous settings, e.g. Gaussian. Our algorithm, which is centered around using examples to sculpt a convex body containing the unknown parameters, is a significant departure from existing provable algorithms for learning feedforward networks, which predominantly exploit algebraic and rotation invariance properties of the Gaussian distribution. In contrast, our analysis is more flexible as it primarily relies on various upper and lower tail bounds for the input distribution and ‘slices’ thereof.

arxiv情報

著者 Sitan Chen,Yuanzhi Li
発行日 2024-02-06 15:39:09+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.DS, cs.LG, stat.ML パーマリンク