Tabular Data Contrastive Learning via Class-Conditioned and Feature-Correlation Based Augmentation

要約

対照学習は、最初に元のデータの同様のビューを作成し、次にデータとそれに対応するビューが埋め込み空間内で近くなるように促すことによるモデルの事前トレーニング手法です。
対照学習は、直観的で効果的なドメイン固有の拡張技術のおかげで、画像データと自然言語データで成功を収めています。
それにもかかわらず、表形式のドメインでは、ビューを作成するための主な拡張手法は、値の交換による表形式のエントリの破損によるものであり、それほど健全でも効果的でもありません。
私たちは、この拡張手法に対するシンプルかつ強力な改善策、つまりクラスのアイデンティティを条件とした表形式データの破損を提案します。
具体的には、アンカー行から特定の表エントリを破損する場合、テーブル全体から同じ特徴列の値をランダムに均一にサンプリングするのではなく、アンカー行と同じクラス内にあると識別された行からのみサンプリングします。
半教師あり学習設定を想定し、テーブルのすべての行にわたってクラス ID を取得するために擬似ラベリング手法を採用します。
また、特徴相関構造に基づいて破損する特徴を選択するという新しいアイデアも検討します。
広範な実験により、提案されたアプローチが、表形式のデータ分類タスクに対して従来の破損手法よりも一貫して優れていることが示されています。
私たちのコードは https://github.com/willtop/Tabular-Class-Conditioned-SSL で入手できます。

要約(オリジナル)

Contrastive learning is a model pre-training technique by first creating similar views of the original data, and then encouraging the data and its corresponding views to be close in the embedding space. Contrastive learning has witnessed success in image and natural language data, thanks to the domain-specific augmentation techniques that are both intuitive and effective. Nonetheless, in tabular domain, the predominant augmentation technique for creating views is through corrupting tabular entries via swapping values, which is not as sound or effective. We propose a simple yet powerful improvement to this augmentation technique: corrupting tabular data conditioned on class identity. Specifically, when corrupting a specific tabular entry from an anchor row, instead of randomly sampling a value in the same feature column from the entire table uniformly, we only sample from rows that are identified to be within the same class as the anchor row. We assume the semi-supervised learning setting, and adopt the pseudo labeling technique for obtaining class identities over all table rows. We also explore the novel idea of selecting features to be corrupted based on feature correlation structures. Extensive experiments show that the proposed approach consistently outperforms the conventional corruption method for tabular data classification tasks. Our code is available at https://github.com/willtop/Tabular-Class-Conditioned-SSL.

arxiv情報

著者 Wei Cui,Rasa Hosseinzadeh,Junwei Ma,Tongzi Wu,Yi Sui,Keyvan Golestan
発行日 2024-04-30 14:11:15+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG, stat.ML パーマリンク