How Feature Learning Can Improve Neural Scaling Laws

要約

私たちは、カーネルの限界を超えたニューラル スケーリング則の解決可能なモデルを開発します。
このモデルの理論分析では、モデルのサイズ、トレーニング時間、利用可能なデータの総量に応じてパフォーマンスがどのように変化するかを示します。
さまざまなタスクの難易度に対応する 3 つのスケーリング方式を特定します (難しいタスク、簡単なタスク、超簡単なタスク)。
初期の無限幅ニューラル タンジェント カーネル (NTK) によって定義される再現カーネル ヒルベルト空間 (RKHS) 内にある簡単および超簡単のターゲット関数の場合、スケーリング指数は特徴学習モデルとカーネル レジーム モデルの間で変化しません。
初期の NTK の RKHS の外側にあるものとして定義されるハード タスクについては、特徴学習によってトレーニング時間とコンピューティングのスケーリングが向上し、ハード タスクの指数がほぼ 2 倍になることが分析と経験の両方で実証されています。
これにより、特徴学習領域でパラメーターとトレーニング時間をスケーリングするための最適な戦略を別の計算に導きます。
私たちは、円上のべき乗則フーリエスペクトルを使用して関数をフィッティングする非線形MLPの実験とビジョンタスクを学習するCNNの実験により、特徴学習は難しいタスクのスケーリング則を改善するが、簡単なタスクや超簡単なタスクでは改善しないという発見を裏付けています。

要約(オリジナル)

We develop a solvable model of neural scaling laws beyond the kernel limit. Theoretical analysis of this model shows how performance scales with model size, training time, and the total amount of available data. We identify three scaling regimes corresponding to varying task difficulties: hard, easy, and super easy tasks. For easy and super-easy target functions, which lie in the reproducing kernel Hilbert space (RKHS) defined by the initial infinite-width Neural Tangent Kernel (NTK), the scaling exponents remain unchanged between feature learning and kernel regime models. For hard tasks, defined as those outside the RKHS of the initial NTK, we demonstrate both analytically and empirically that feature learning can improve scaling with training time and compute, nearly doubling the exponent for hard tasks. This leads to a different compute optimal strategy to scale parameters and training time in the feature learning regime. We support our finding that feature learning improves the scaling law for hard tasks but not for easy and super-easy tasks with experiments of nonlinear MLPs fitting functions with power-law Fourier spectra on the circle and CNNs learning vision tasks.

arxiv情報

著者 Blake Bordelon,Alexander Atanasov,Cengiz Pehlevan
発行日 2024-09-26 14:05:32+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cond-mat.dis-nn, cs.LG, stat.ML パーマリンク