A stochastic optimization approach to train non-linear neural networks with regularization of higher-order total variation

要約

ディープニューラルネットワークを含む表現力の高いパラメトリックモデルは、複雑な概念をモデル化するのに有利であるが、このような非線形性の高いモデルの学習は、悪名高いオーバーフィッティングの高いリスクをもたらすことが知られている。この問題に対処するため、本研究では$k$次全変動($k$-TV)正則化を検討する。$k$-TV正則化は、学習するパラメトリックモデルの$k$次微分の二乗積分として定義される。一般的なパラメトリックモデルに適用される$k$-TV項は積分により計算が困難であるが、本研究では、明示的な数値積分を行うことなく、$k$-TV正則化により一般モデルを効率的に訓練できる確率最適化アルゴリズムを提供する。提案手法は、単純な確率的勾配降下アルゴリズムと自動微分だけで実装できるため、構造が任意なディープニューラルネットワークの学習にも適用できる。我々の数値実験により、$K$-TV項を用いて訓練したニューラルネットワークは、従来のパラメータ正則化を用いたニューラルネットワークよりも“弾力的”であることが実証された。また、提案アルゴリズムは物理情報を用いたニューラルネットワークの訓練(PINN)にも拡張可能である。

要約(オリジナル)

While highly expressive parametric models including deep neural networks have an advantage to model complicated concepts, training such highly non-linear models is known to yield a high risk of notorious overfitting. To address this issue, this study considers a $k$th order total variation ($k$-TV) regularization, which is defined as the squared integral of the $k$th order derivative of the parametric models to be trained; penalizing the $k$-TV is expected to yield a smoother function, which is expected to avoid overfitting. While the $k$-TV terms applied to general parametric models are computationally intractable due to the integration, this study provides a stochastic optimization algorithm, that can efficiently train general models with the $k$-TV regularization without conducting explicit numerical integration. The proposed approach can be applied to the training of even deep neural networks whose structure is arbitrary, as it can be implemented by only a simple stochastic gradient descent algorithm and automatic differentiation. Our numerical experiments demonstrate that the neural networks trained with the $K$-TV terms are more “resilient” than those with the conventional parameter regularization. The proposed algorithm also can be extended to the physics-informed training of neural networks (PINNs).

arxiv情報

著者 Akifumi Okuno
発行日 2023-08-04 12:57:13+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, DeepL

カテゴリー: cs.AI, cs.LG, stat.CO, stat.ME, stat.ML パーマリンク