Cramer Type Distances for Learning Gaussian Mixture Models by Gradient Descent

要約

混合ガウス モデル (単に GMM とも呼ばれる) の学習は、機械学習において重要な役割を果たします。
表現力と解釈可能性で知られる混合ガウス モデルは、統計学、コンピューター ビジョンから分布強化学習まで、幅広い用途に使用できます。
ただし、現時点では、これらのモデルに適合または学習できる既知のアルゴリズムはほとんどなく、その一部には期待値最大化アルゴリズムやスライス ワッサースタイン ディスタンスが含まれます。
ニューラル ネットワークの一般的な学習プロセスである勾配降下法と互換性のあるアルゴリズムはさらに少数です。
この論文では、単変量 1 次元の場合で 2 つの GMM の閉じた式を導出し、一般的な多変量 GMM を学習するための Sliced Cram\’er 2- distance と呼ばれる距離関数を提案します。
私たちのアプローチには、以前の多くの方法に比べていくつかの利点があります。
まず、単変量の場合に閉じた形式の式があり、一般的な機械学習ライブラリ (PyTorch や TensorFlow など) を使用して計算と実装が簡単です。
2 番目に、勾配降下法と互換性があるため、GMM をニューラル ネットワークとシームレスに統合できます。
第三に、GMM をデータ ポイントのセットに適合させるだけでなく、ターゲット モデルからサンプリングすることなく別の GMM に直接適合させることもできます。
そして 4 番目に、グローバル勾配の境界性や不偏サンプリング勾配などの理論的な保証があります。
これらの機能は、将来の報酬に関する分布を学習することが目標である分布強化学習と Deep Q ネットワークに特に役立ちます。
また、その有効性を実証するためのおもちゃの例として、ガウス混合分布ディープ Q ネットワークを構築します。
以前のモデルと比較して、このモデルは分布を表すという点でパラメータ効率が高く、より優れた解釈可能性を備えています。

要約(オリジナル)

The learning of Gaussian Mixture Models (also referred to simply as GMMs) plays an important role in machine learning. Known for their expressiveness and interpretability, Gaussian mixture models have a wide range of applications, from statistics, computer vision to distributional reinforcement learning. However, as of today, few known algorithms can fit or learn these models, some of which include Expectation-Maximization algorithms and Sliced Wasserstein Distance. Even fewer algorithms are compatible with gradient descent, the common learning process for neural networks. In this paper, we derive a closed formula of two GMMs in the univariate, one-dimensional case, then propose a distance function called Sliced Cram\’er 2-distance for learning general multivariate GMMs. Our approach has several advantages over many previous methods. First, it has a closed-form expression for the univariate case and is easy to compute and implement using common machine learning libraries (e.g., PyTorch and TensorFlow). Second, it is compatible with gradient descent, which enables us to integrate GMMs with neural networks seamlessly. Third, it can fit a GMM not only to a set of data points, but also to another GMM directly, without sampling from the target model. And fourth, it has some theoretical guarantees like global gradient boundedness and unbiased sampling gradient. These features are especially useful for distributional reinforcement learning and Deep Q Networks, where the goal is to learn a distribution over future rewards. We will also construct a Gaussian Mixture Distributional Deep Q Network as a toy example to demonstrate its effectiveness. Compared with previous models, this model is parameter efficient in terms of representing a distribution and possesses better interpretability.

arxiv情報

著者 Ruichong Zhang
発行日 2023-07-13 13:43:02+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.LG, stat.ML パーマリンク