Learning Differentiable Surrogate Losses for Structured Prediction

要約

構造化予測には、単純なスカラー値ではなく複雑な構造を予測する方法を学習することが含まれます。
主な課題は出力空間の非ユークリッドの性質から生じており、一般に問題の定式化を緩和する必要があります。
サロゲート メソッドは、カーネルに起因する損失、またはより一般的には、暗黙的な損失埋め込みを許可する損失関数に基づいて構築され、元の問題を回帰タスクに変換し、その後に復号化ステップが続きます。
ただし、複雑な構造を持つオブジェクトの実効損失を設計することには大きな課題があり、多くの場合、ドメイン固有の専門知識が必要です。
この研究では、教師あり代理回帰問題に取り組む前に、ニューラル ネットワークによってパラメータ化された構造化損失関数が、対照学習を通じて出力トレーニング データから直接学習される新しいフレームワークを導入します。
その結果、微分可能な損失により、サロゲート空間の有限次元によるニューラル ネットワークの学習が可能になるだけでなく、勾配降下法に基づく復号戦略を介して出力データの新しい構造の予測も可能になります。
教師ありグラフ予測問題に関する数値実験では、私たちのアプローチが、事前定義されたカーネルに基づく方法と同等かそれ以上のパフォーマンスを達成することが示されています。

要約(オリジナル)

Structured prediction involves learning to predict complex structures rather than simple scalar values. The main challenge arises from the non-Euclidean nature of the output space, which generally requires relaxing the problem formulation. Surrogate methods build on kernel-induced losses or more generally, loss functions admitting an Implicit Loss Embedding, and convert the original problem into a regression task followed by a decoding step. However, designing effective losses for objects with complex structures presents significant challenges and often requires domain-specific expertise. In this work, we introduce a novel framework in which a structured loss function, parameterized by neural networks, is learned directly from output training data through Contrastive Learning, prior to addressing the supervised surrogate regression problem. As a result, the differentiable loss not only enables the learning of neural networks due to the finite dimension of the surrogate space but also allows for the prediction of new structures of the output data via a decoding strategy based on gradient descent. Numerical experiments on supervised graph prediction problems show that our approach achieves similar or even better performance than methods based on a pre-defined kernel.

arxiv情報

著者 Junjie Yang,Matthieu Labeau,Florence d’Alché-Buc
発行日 2024-11-18 16:07:47+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.LG, stat.ML パーマリンク