Conv-Basis: A New Paradigm for Efficient Attention Inference and Gradient Computation in Transformers

要約

大規模言語モデル (LLM) は世界を大きく変えました。
自己注意メカニズムは、LLM におけるトランスフォーマーの成功の鍵です。
ただし、長さ $n$ の入力シーケンスに対する二次計算コスト $O(n^2)$ は、より長いコンテキストでのさらなる改善とスケーラビリティにとって悪名高い障害となります。
この研究では、アテンション行列の畳み込みのような構造を利用して、コンボリューション行列を使用したアテンション計算の効率的な近似方法を開発します。
ランク基底に「似た」 $\mathsf{conv}$ 基底系を提案し、この基底系ではどんな下三角 (注意) 行列も常に $k$ 構造化畳み込み行列の和として分解できることを示します。
次に、アテンション行列を $k$ 畳み込み行列にすばやく分解するアルゴリズムを設計します。
高速フーリエ変換 (FFT) のおかげで、アテンション {\it inference} は $O(knd \log n)$ 時間で計算できます。ここで $d$ は隠れ次元です。
実際には、$ d \ll n$、つまり、Gemma の $d=3,072$ および $n=1,000,000$ があります。
したがって、$kd = n^{o(1)}$ の場合、アルゴリズムはほぼ線形の時間を達成します (つまり、$n^{1+o(1)}$)。
さらに、アテンション {\it training forward} と {\it backward gradient} も $n^{1+o(1)}$ で計算できます。
私たちのアプローチでは、$n \times n$ 注目行列の明示的な計算を回避できるため、二次計算の複雑さが大幅に軽減される可能性があります。
さらに、私たちのアルゴリズムはあらゆる入力行列に対して機能します。
この研究は、トランスフォーマーのアテンション計算を加速し、より長いコンテキストへのアプリケーションを可能にするための新しいパラダイムを提供します。

要約(オリジナル)

Large Language Models (LLMs) have profoundly changed the world. Their self-attention mechanism is the key to the success of transformers in LLMs. However, the quadratic computational cost $O(n^2)$ to the length $n$ input sequence is the notorious obstacle for further improvement and scalability in the longer context. In this work, we leverage the convolution-like structure of attention matrices to develop an efficient approximation method for attention computation using convolution matrices. We propose a $\mathsf{conv}$ basis system, ‘similar’ to the rank basis, and show that any lower triangular (attention) matrix can always be decomposed as a sum of $k$ structured convolution matrices in this basis system. We then design an algorithm to quickly decompose the attention matrix into $k$ convolution matrices. Thanks to Fast Fourier Transforms (FFT), the attention {\it inference} can be computed in $O(knd \log n)$ time, where $d$ is the hidden dimension. In practice, we have $ d \ll n$, i.e., $d=3,072$ and $n=1,000,000$ for Gemma. Thus, when $kd = n^{o(1)}$, our algorithm achieve almost linear time, i.e., $n^{1+o(1)}$. Furthermore, the attention {\it training forward} and {\it backward gradient} can be computed in $n^{1+o(1)}$ as well. Our approach can avoid explicitly computing the $n \times n$ attention matrix, which may largely alleviate the quadratic computational complexity. Furthermore, our algorithm works on any input matrices. This work provides a new paradigm for accelerating attention computation in transformers to enable their application to longer contexts.

arxiv情報

著者 Jiuxiang Gu,Yingyu Liang,Heshan Liu,Zhenmei Shi,Zhao Song,Junze Yin
発行日 2024-05-08 17:11:38+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG パーマリンク