Robust low-rank training via approximate orthonormal constraints

要約

モデルやデータサイズの増大に伴い、モデルの性能を維持しつつ、深層学習パイプラインのリソース要求を低減するプルーニング技術を設計するための幅広い取り組みが行われています。推論コストと学習コストの両方を削減するために、ネットワークの重みを表現するために低ランクの行列分解を使用する著名な研究がある。精度は維持できるものの、低ランクの手法は、敵対的な摂動に対するモデルの頑健性を損なう傾向があることが分かっています。ロバスト性をニューラルネットワークの条件数でモデル化することで、このロバスト性の低下は、低ランク重み行列の特異値が爆発的に増加することに起因していると主張する。そこで、ネットワークの重みを低ランク行列の多様体に維持すると同時に、近似的な正規制約を適用する、ロバストな低ランクトレーニングアルゴリズムを導入する。その結果、モデルの精度を落とすことなく、学習と推論のコストを削減し、良好な条件付けを保証することで、より優れた敵対的頑健性を実現する。このことは、広範な数値的証拠と、高性能な低ランクのサブネットワークが存在する場合に、計算された頑健な低ランクネットワークが理想的なフルモデルによく近似することを示す主な近似定理によって示される。

要約(オリジナル)

With the growth of model and data sizes, a broad effort has been made to design pruning techniques that reduce the resource demand of deep learning pipelines, while retaining model performance. In order to reduce both inference and training costs, a prominent line of work uses low-rank matrix factorizations to represent the network weights. Although able to retain accuracy, we observe that low-rank methods tend to compromise model robustness against adversarial perturbations. By modeling robustness in terms of the condition number of the neural network, we argue that this loss of robustness is due to the exploding singular values of the low-rank weight matrices. Thus, we introduce a robust low-rank training algorithm that maintains the network’s weights on the low-rank matrix manifold while simultaneously enforcing approximate orthonormal constraints. The resulting model reduces both training and inference costs while ensuring well-conditioning and thus better adversarial robustness, without compromising model accuracy. This is shown by extensive numerical evidence and by our main approximation theorem that shows the computed robust low-rank network well-approximates the ideal full model, provided a highly performing low-rank sub-network exists.

arxiv情報

著者 Dayana Savostianova,Emanuele Zangrando,Gianluca Ceruti,Francesco Tudisco
発行日 2023-06-02 12:22:35+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, DeepL

カテゴリー: cs.AI, cs.LG, cs.NA, math.NA, stat.ML パーマリンク