GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent

要約

デシジョン ツリー (DT) は、解釈可能性が高いため、多くの機械学習タスクに一般的に使用されます。
ただし、DT は非凸で微分不可能であるため、データから DT を学習することは困難な最適化問題です。
したがって、一般的なアプローチでは、各内部ノードで局所的に不純物を最小限に抑える貪欲な成長アルゴリズムを使用して DT を学習します。
残念ながら、この貪欲な手順では不正確なツリーが生成される可能性があります。
この論文では、勾配降下法を使用してハードな軸揃え DT を学習するための新しいアプローチを紹介します。
提案された方法は、密な DT 表現でストレートスルー演算子を使用したバックプロパゲーションを使用して、すべてのツリー パラメーターを共同で最適化します。
私たちのアプローチは、バイナリ分類ベンチマークで既存の手法を上回り、マルチクラス タスクで競合する結果を達成します。
このメソッドは、https://github.com/s-marton/GradTree で入手できます。

要約(オリジナル)

Decision Trees (DTs) are commonly used for many machine learning tasks due to their high degree of interpretability. However, learning a DT from data is a difficult optimization problem, as it is non-convex and non-differentiable. Therefore, common approaches learn DTs using a greedy growth algorithm that minimizes the impurity locally at each internal node. Unfortunately, this greedy procedure can lead to inaccurate trees. In this paper, we present a novel approach for learning hard, axis-aligned DTs with gradient descent. The proposed method uses backpropagation with a straight-through operator on a dense DT representation, to jointly optimize all tree parameters. Our approach outperforms existing methods on binary classification benchmarks and achieves competitive results for multi-class tasks. The method is available under: https://github.com/s-marton/GradTree

arxiv情報

著者 Sascha Marton,Stefan Lüdtke,Christian Bartelt,Heiner Stuckenschmidt
発行日 2024-08-19 14:34:22+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG パーマリンク