Scalify: scale propagation for efficient low-precision LLM training

要約

float8 などの低精度形式は、大規模な言語モデルのトレーニングと推論の計算効率を向上させるために、機械学習アクセラレーション ハードウェアに導入されています。
それにもかかわらず、より高精度のトレーニング精度に匹敵するために必要な複雑で、場合によっては脆弱な技術により、ML コミュニティでの採用は遅れています。
この研究では、既存のテンソル スケーリング手法を一般化および形式化する、計算グラフのエンドツーエンドのスケール伝播パラダイムである Scalify を紹介します。
実験結果は、Scalify がすぐに使用できる float8 行列の乗算と勾配表現、および float16 オプティマイザー状態ストレージをサポートしていることを示しています。
Scalify の JAX 実装は、https://github.com/graphcore-research/jax-scalify でオープンソース化されています。

要約(オリジナル)

Low-precision formats such as float8 have been introduced in machine learning accelerated hardware to improve computational efficiency for large language models training and inference. Nevertheless, adoption by the ML community has been slowed down by the complex, and sometimes brittle, techniques required to match higher precision training accuracy. In this work, we present Scalify, a end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experiment results show that Scalify supports out-of-the-box float8 matrix multiplication and gradients representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify

arxiv情報

著者 Paul Balança,Sam Hosegood,Carlo Luschi,Andrew Fitzgibbon
発行日 2024-07-24 15:26:01+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: 68T07, cs.LG, I.2.7 パーマリンク