Fast Training of Diffusion Models with Masked Transformers

要約

マスクされたトランスを使用して大規模な拡散モデルをトレーニングするための効率的なアプローチを提案します。
マスクされた変換器は表現学習のために広く研究されてきましたが、生成学習への応用は視覚領域ではあまり研究されていません。
私たちの研究は、マスクされたトレーニングを利用して拡散モデルのトレーニング コストを大幅に削減した最初の研究です。
具体的には、トレーニング中に拡散入力画像内のパッチの高い割合 (\emph{例}、50\%) をランダムにマスクアウトします。
マスクされたトレーニングの場合、マスクされていないパッチでのみ動作するトランスフォーマー エンコーダーと完全なパッチで動作する軽量のトランスフォーマー デコーダーで構成される非対称エンコーダー/デコーダー アーキテクチャを導入します。
完全なパッチの長期的な理解を促進するために、マスクされていないパッチのスコアを学習するノイズ除去スコア マッチング目標に、マスクされたパッチを再構築する補助タスクを追加します。
ImageNet-256$\times$256 での実験では、私たちのアプローチが元のトレーニング時間のわずか 31\% を使用して、最先端の拡散変換器 (DiT) モデルと同じパフォーマンスを達成できることがわかりました。
したがって、私たちの方法では、生成パフォーマンスを犠牲にすることなく、拡散モデルの効率的なトレーニングが可能になります。

要約(オリジナル)

We propose an efficient approach to train large diffusion models with masked transformers. While masked transformers have been extensively explored for representation learning, their application to generative learning is less explored in the vision domain. Our work is the first to exploit masked training to reduce the training cost of diffusion models significantly. Specifically, we randomly mask out a high proportion (\emph{e.g.}, 50\%) of patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder architecture consisting of a transformer encoder that operates only on unmasked patches and a lightweight transformer decoder on full patches. To promote a long-range understanding of full patches, we add an auxiliary task of reconstructing masked patches to the denoising score matching objective that learns the score of unmasked patches. Experiments on ImageNet-256$\times$256 show that our approach achieves the same performance as the state-of-the-art Diffusion Transformer (DiT) model, using only 31\% of its original training time. Thus, our method allows for efficient training of diffusion models without sacrificing the generative performance.

arxiv情報

著者 Hongkai Zheng,Weili Nie,Arash Vahdat,Anima Anandkumar
発行日 2023-06-15 17:38:48+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.CV, cs.LG パーマリンク