Taming the Sigmoid Bottleneck: Provably Argmaxable Sparse Multi-Label Classification

要約

シグモイド出力層は、任意の入力に複数のラベルを割り当てることができるマルチラベル分類 (MLC) タスクで広く使用されています。
多くの実際の MLC タスクでは、使用可能なラベルの数は数千に達し、多くの場合、入力特徴の数を超えて、低ランクの出力層が生成されます。
マルチクラス分類では、このような低ランクの出力層がボトルネックとなり、argmaxable クラス、つまりどの入力に対しても予測できないクラスが生成される可能性があることが知られています。
この論文では、MLC タスクの場合、同様のシグモイド ボトルネックにより、引数最大化できないラベルの組み合わせが指数関数的に増加することを示します。
これらの引数最大化できない出力を検出する方法を説明し、広く使用されている 3 つの MLC データセットでその出力の存在を実証します。
次に、離散フーリエ変換 (DFT) 出力層を導入することで、実際にそれらを防止できることを示します。これにより、$k$ までのアクティブ ラベルを持つすべてのスパース ラベルの組み合わせが argmaxable であることが保証されます。
私たちの DFT 層はトレーニングが高速でパラメータ効率が高く、使用するトレーニング可能なパラメータを最大 50% 減らしながらシグモイド層の F1@k スコアと一致します。
私たちのコードは https://github.com/andreasgrv/sigmoid-bottleneck で公開されています。

要約(オリジナル)

Sigmoid output layers are widely used in multi-label classification (MLC) tasks, in which multiple labels can be assigned to any input. In many practical MLC tasks, the number of possible labels is in the thousands, often exceeding the number of input features and resulting in a low-rank output layer. In multi-class classification, it is known that such a low-rank output layer is a bottleneck that can result in unargmaxable classes: classes which cannot be predicted for any input. In this paper, we show that for MLC tasks, the analogous sigmoid bottleneck results in exponentially many unargmaxable label combinations. We explain how to detect these unargmaxable outputs and demonstrate their presence in three widely used MLC datasets. We then show that they can be prevented in practice by introducing a Discrete Fourier Transform (DFT) output layer, which guarantees that all sparse label combinations with up to $k$ active labels are argmaxable. Our DFT layer trains faster and is more parameter efficient, matching the F1@k score of a sigmoid layer while using up to 50% fewer trainable parameters. Our code is publicly available at https://github.com/andreasgrv/sigmoid-bottleneck.

arxiv情報

著者 Andreas Grivas,Antonio Vergari,Adam Lopez
発行日 2024-01-29 17:14:01+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.LG パーマリンク