TaskMet: Task-Driven Metric Learning for Model Learning

要約

深層学習モデルは多くの場合、トレーニング手順では認識されない下流タスクにデプロイされます。
たとえば、正確な予測を達成するためだけにトレーニングされたモデルは、一見小さな予測誤差が重大なタスク エラーを引き起こす可能性があるため、下流のタスクで適切に実行するのが難しい場合があります。
標準的なエンドツーエンド学習アプローチは、タスク損失を微分可能にするか、モデルをトレーニングできる微分可能なサロゲートを導入することです。
これらの設定では、目標が矛盾する可能性があるため、タスク損失と予測損失のバランスを注意深く取る必要があります。
私たちは、モデルのパラメーターよりも 1 レベル深いタスク損失信号を取得し、それを使用してモデルがトレーニングされる損失関数のパラメーターを学習することを提案します。これは、予測空間でメトリックを学習することで実行できます。
このアプローチは、最適な予測モデル自体を変更するのではなく、下流のタスクにとって重要な情報を強調するようにモデル学習を変更します。
これにより、元の予測空間でトレーニングされた予測モデルでありながら、必要な下流タスクにとっても価値のある予測モデルという、両方の長所を実現することができます。
私たちは、2 つの主要な設定で行われた実験を通じてアプローチを検証します。1) ポートフォリオの最適化と予算配分を含む意思決定に重点を置いたモデル学習シナリオ、2) 気を散らす状態の騒がしい環境での強化学習。
実験を再現するためのソース コードは https://github.com/facebookresearch/taskmet で入手できます。

要約(オリジナル)

Deep learning models are often deployed in downstream tasks that the training procedure may not be aware of. For example, models solely trained to achieve accurate predictions may struggle to perform well on downstream tasks because seemingly small prediction errors may incur drastic task errors. The standard end-to-end learning approach is to make the task loss differentiable or to introduce a differentiable surrogate that the model can be trained on. In these settings, the task loss needs to be carefully balanced with the prediction loss because they may have conflicting objectives. We propose take the task loss signal one level deeper than the parameters of the model and use it to learn the parameters of the loss function the model is trained on, which can be done by learning a metric in the prediction space. This approach does not alter the optimal prediction model itself, but rather changes the model learning to emphasize the information important for the downstream task. This enables us to achieve the best of both worlds: a prediction model trained in the original prediction space while also being valuable for the desired downstream task. We validate our approach through experiments conducted in two main settings: 1) decision-focused model learning scenarios involving portfolio optimization and budget allocation, and 2) reinforcement learning in noisy environments with distracting states. The source code to reproduce our experiments is available at https://github.com/facebookresearch/taskmet

arxiv情報

著者 Dishank Bansal,Ricky T. Q. Chen,Mustafa Mukadam,Brandon Amos
発行日 2023-12-08 18:59:03+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG, math.OC, stat.ML パーマリンク