FlashRNN: Optimizing Traditional RNNs on Modern Hardware

要約

Transformer やその他のシーケンス並列化可能なニューラル ネットワーク アーキテクチャは、シーケンス モデリングにおける現在の最先端のように見えますが、特に状態追跡機能が不足しています。
これらは、時系列タスクと論理的推論にとって重要です。
LSTM や GRU などの従来の RNN、および sLSTM などの最新のバリアントには、厳密な逐次処理を犠牲にしてこれらの機能があります。
これは多くの場合、強力な制限であると見なされますが、Triton および CUDA のハードウェア最適化 FlashRNN を使用して、最新の GPU のレジスタ レベルまでカーネルを最適化することで、これらのネットワークがどれほど高速になるかを示します。
従来の RNN を、Transformers のヘッドワイズ処理と同様に、より小さい隠れ状態の複数の RNN を並列処理する並列化バリアントで拡張します。
さまざまな GPU バリアントでの柔軟性を可能にするために、ハードウェア内部キャッシュ サイズ、メモリ、およびコンピューティング処理のための新しい最適化フレームワークを導入します。
割り算の概念を含む多面体のような制約を使用して、設定内のハードウェアをモデル化します。
これにより、一般的な整数制約満足問題 (整数 CSP) の ConstrINT ライブラリでの解決プロセスが高速化されます。
私たちのカーネルは、通常の PyTorch 実装と比較して 50 倍の高速化を達成でき、Triton 実装と比較して 40 倍の隠蔽サイズを許可できることを示します。
私たちのオープンソース カーネルと最適化ライブラリは、状態追跡対応 RNN とシーケンス モデリングの方向の研究を促進するためにここでリリースされています: \url{https://github.com/NX-AI/flashrnn}

要約(オリジナル)

While Transformers and other sequence-parallelizable neural network architectures seem like the current state of the art in sequence modeling, they specifically lack state-tracking capabilities. These are important for time-series tasks and logical reasoning. Traditional RNNs like LSTMs and GRUs, as well as modern variants like sLSTM do have these capabilities at the cost of strictly sequential processing. While this is often seen as a strong limitation, we show how fast these networks can get with our hardware-optimization FlashRNN in Triton and CUDA, optimizing kernels to the register level on modern GPUs. We extend traditional RNNs with a parallelization variant that processes multiple RNNs of smaller hidden state in parallel, similar to the head-wise processing in Transformers. To enable flexibility on different GPU variants, we introduce a new optimization framework for hardware-internal cache sizes, memory and compute handling. It models the hardware in a setting using polyhedral-like constraints, including the notion of divisibility. This speeds up the solution process in our ConstrINT library for general integer constraint satisfaction problems (integer CSPs). We show that our kernels can achieve 50x speed-ups over a vanilla PyTorch implementation and allow 40x larger hidden sizes compared to our Triton implementation. Our open-source kernels and the optimization library are released here to boost research in the direction of state-tracking enabled RNNs and sequence modeling: \url{https://github.com/NX-AI/flashrnn}

arxiv情報

著者 Korbinian Pöppel,Maximilian Beck,Sepp Hochreiter
発行日 2024-12-10 18:50:37+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG パーマリンク