Gated Linear Attention Transformers with Hardware-Efficient Training

要約

線形アテンションを備えたトランスフォーマーは効率的な並列トレーニングを可能にしますが、同時に 2D (行列値) 隠れ状態を持つ RNN として定式化できるため、(出力長に関して) 線形推論の複雑さを享受できます。
RetNet (Sun et al., 2023) や TransNormerLLM (Qin et al., 2023a) などの最近の研究では、加法的 RNN 更新ルールにグローバル減衰項を追加するとパフォーマンスが大幅に向上し、場合によっては、
規模。
この研究では、データ依存のゲート機構を追加することでパフォーマンスがさらに向上することを示します。
私たちは、効率的なトレーニングを可能にする、このゲートされた線形注意層の並列形式を導出します。
ただし、この並列形式の単純で数値的に安定した実装には、数値の安定性のために対数空間での一般化された行列の乗算が必要であるため、標準の行列の乗算用に最適化された最新の GPU ではテンソル コアを利用できません。
私たちは、シーケンス チャンクに対するブロック並列計算を通じてテンソル コアを利用できる並列形式のハードウェア効率の高いバージョンを開発します。
中規模の言語モデリングに関する実験 (15B トークンでトレーニングされた 340M パラメーター モデル、100B トークンでトレーニングされた 1.3B パラメーター モデル) は、ゲート線形アテンション (GLA) Transformer が強力な LLaMA アーキテクチャ Transformer ベースラインに対して競合的に機能することを示しています (Touvron et al
., 2023) と、データ依存の状態遷移メカニズムを備えた最近導入された状態空間モデルである Mamba (Gu & Dao, 2023) も同様です。
トレーニング速度に関しては、Triton ベースの実装は、通常の 2048 トレーニング長設定では CUDA に最適化された FlashAttendant-2 (Dao、2023) と同等のパフォーマンスを発揮しますが、4096 を超える長いシーケンスでトレーニングする場合は FlashAttendant-2 を上回ります。

要約(オリジナル)

Transformers with linear attention allow for efficient parallel training but can simultaneously be formulated as an RNN with 2D (matrix-valued) hidden states, thus enjoying linear (with respect to output length) inference complexity. Recent works such as RetNet (Sun et al., 2023) and TransNormerLLM (Qin et al., 2023a) observe that adding a global decay term to the additive RNN update rule greatly improves performance, sometimes outperforming standard Transformers with softmax attention when trained at scale. In this work we show that adding a data-dependent gating mechanism further improves performance. We derive a parallel form of this gated linear attention layer that enables efficient training. However, a straightforward, numerically stable implementation of this parallel form requires generalized matrix multiplications in log-space for numerical stability, and thus cannot take advantage of tensor cores on modern GPUs which are optimized for standard matrix multiplications. We develop a hardware-efficient version of the parallel form that can still make use of tensor cores through block-parallel computations over sequence chunks. Experiments on moderate-scale language modeling (340M-parameter models trained on 15B tokens, 1.3B-parameter models trained on 100B tokens) show that gated linear attention (GLA) Transformers perform competitively against a strong LLaMA-architecture Transformer baseline (Touvron et al., 2023) as well as Mamba (Gu & Dao, 2023), a recently introduced state-space model with a data-dependent state transition mechanism. For training speed, our Triton-based implementation performs comparably to CUDA-optimized FlashAttention-2 (Dao, 2023) under the regular 2048 training length setting, while outperforming FlashAttention-2 when training on longer sequences beyond 4096.

arxiv情報

著者 Songlin Yang,Bailin Wang,Yikang Shen,Rameswar Panda,Yoon Kim
発行日 2023-12-12 06:04:14+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.LG パーマリンク