Correct-N-Contrast: A Contrastive Approach for Improving Robustness to Spurious Correlations

要約

スプリアス相関は、堅牢な機械学習にとって大きな課題となります。
経験的リスク最小化 (ERM) でトレーニングされたモデルは、クラス ラベルと偽の属性間の相関関係に依存することを学習し、これらの相関関係のないデータ グループのパフォーマンスが低下する可能性があります。
これは、偽の属性ラベルが利用できない場合に対処するのが特に困難です。
属性ラベルをトレーニングせずに擬似相関データの最悪グループのパフォーマンスを改善するために、擬似相関に強い表現を直接学習する対照的なアプローチである Correct-N-Contrast (CNC) を提案します。
ERM モデルは優れたスプリアス属性予測子となる可能性があるため、CNC は、(1) トレーニングされた ERM モデルの出力を使用して、クラスは同じだが異なるスプリアス特徴を持つサンプルを識別し、(2) 対照学習で堅牢なモデルをトレーニングして、類似した表現を学習します。
同クラスのサンプル。
CNC をサポートするために、CNC が最小化を目指す最悪グループ誤差と表現アライメント損失との間に新しい関係を導入します。
我々は、最悪グループのエラーがアライメント損失と密接に関係していることを経験的に観察し、クラス全体にわたるアライメント損失がクラスの最悪グループと平均エラーギャップの上限を高めるのに役立つことを証明しています。
一般的なベンチマークでは、CNC はアライメント損失を大幅に削減し、平均絶対リフト 3.6% による最先端の最悪グループ精度を達成します。
CNC は、グループ ラベルを必要とするオラクル手法とも競合します。

要約(オリジナル)

Spurious correlations pose a major challenge for robust machine learning. Models trained with empirical risk minimization (ERM) may learn to rely on correlations between class labels and spurious attributes, leading to poor performance on data groups without these correlations. This is particularly challenging to address when spurious attribute labels are unavailable. To improve worst-group performance on spuriously correlated data without training attribute labels, we propose Correct-N-Contrast (CNC), a contrastive approach to directly learn representations robust to spurious correlations. As ERM models can be good spurious attribute predictors, CNC works by (1) using a trained ERM model’s outputs to identify samples with the same class but dissimilar spurious features, and (2) training a robust model with contrastive learning to learn similar representations for same-class samples. To support CNC, we introduce new connections between worst-group error and a representation alignment loss that CNC aims to minimize. We empirically observe that worst-group error closely tracks with alignment loss, and prove that the alignment loss over a class helps upper-bound the class’s worst-group vs. average error gap. On popular benchmarks, CNC reduces alignment loss drastically, and achieves state-of-the-art worst-group accuracy by 3.6% average absolute lift. CNC is also competitive with oracle methods that require group labels.

arxiv情報

著者 Michael Zhang,Nimit S. Sohoni,Hongyang R. Zhang,Chelsea Finn,Christopher Ré
発行日 2024-12-11 17:06:21+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.LG パーマリンク