要約
拡散確率モデル (DPM) は、さまざまな機械学習ドメインで目覚ましい進歩を遂げました。
ただし、高品質の合成サンプルを実現するには、通常、多数のサンプリング ステップを実行する必要があるため、リアルタイムのサンプル合成の可能性が妨げられます。
知識の蒸留による従来の高速サンプリング アルゴリズムは、事前トレーニングされたモデルの重みと離散時間ステップ シナリオに依存しているため、目標を達成するには追加のトレーニング セッションが必要です。
これらの問題に対処するために、速度推定モデルの現在の瞬間の出力が前の瞬間の出力に「追いつく」ことを促進するキャッチアップ蒸留 (CUD) を提案します。
具体的には、CUD は元の常微分方程式 (ODE) トレーニング目標を調整して、現在の瞬間の出力をグランド トゥルース ラベルと前の瞬間の出力の両方に合わせます。これにより、ルンゲ クッタ ベースのマルチステップ アライメント蒸留を利用して、非同期を防止しながら正確な ODE 推定を実現します。
更新情報。
さらに、連続タイムステップシナリオの下で CUD の設計空間を調査し、適切な戦略を決定する方法を分析します。
CUD の有効性を実証するために、CIFAR-10、MNIST、ImageNet-64 で徹底的なアブレーション実験と比較実験を実施しています。
CIFAR-10 では、1 セッションのトレーニングで 15 ステップでサンプリングすることにより 2.80 の FID が得られ、追加のトレーニングを使用して 1 ステップでサンプリングすることにより、新しい最先端の FID 3.37 が得られます。
後者の結果では、バッチ サイズ 128 で 620,000 回の反復のみが必要でしたが、一貫性蒸留ではバッチ サイズ 256 で 2100,000 回の反復が必要でした。コードは https://anonymous.4open.science/r/Catch でリリースされています。
-昇圧蒸留-E31F。
要約(オリジナル)
Diffusion Probability Models (DPMs) have made impressive advancements in various machine learning domains. However, achieving high-quality synthetic samples typically involves performing a large number of sampling steps, which impedes the possibility of real-time sample synthesis. Traditional accelerated sampling algorithms via knowledge distillation rely on pre-trained model weights and discrete time step scenarios, necessitating additional training sessions to achieve their goals. To address these issues, we propose the Catch-Up Distillation (CUD), which encourages the current moment output of the velocity estimation model “catch up” with its previous moment output. Specifically, CUD adjusts the original Ordinary Differential Equation (ODE) training objective to align the current moment output with both the ground truth label and the previous moment output, utilizing Runge-Kutta-based multi-step alignment distillation for precise ODE estimation while preventing asynchronous updates. Furthermore, we investigate the design space for CUDs under continuous time-step scenarios and analyze how to determine the suitable strategies. To demonstrate CUD’s effectiveness, we conduct thorough ablation and comparison experiments on CIFAR-10, MNIST, and ImageNet-64. On CIFAR-10, we obtain a FID of 2.80 by sampling in 15 steps under one-session training and the new state-of-the-art FID of 3.37 by sampling in one step with additional training. This latter result necessitated only 620k iterations with a batch size of 128, in contrast to Consistency Distillation, which demanded 2100k iterations with a larger batch size of 256. Our code is released at https://anonymous.4open.science/r/Catch-Up-Distillation-E31F.
arxiv情報
著者 | Shitong Shao,Xu Dai,Shouyi Yin,Lujun Li,Huanran Chen,Yang Hu |
発行日 | 2023-05-30 16:40:27+00:00 |
arxivサイト | arxiv_id(pdf) |
提供元, 利用サービス
arxiv.jp, Google