Optimizing Automatic Differentiation with Deep Reinforcement Learning

要約

自動微分によるヤコビアンの計算は、機械学習、数値流体力学、ロボット工学、金融などの多くの科学分野で広く普及しています。
ヤコビアン計算における計算数やメモリ使用量がわずかに節約されただけでも、すでにエネルギー消費量と実行時間の大幅な節約につながる可能性があります。
このような節約を可能にする方法は多数存在しますが、一般に、それらは正確なヤコビアンの近似と計算効率を引き換えにします。
この論文では、正確なヤコビアンを計算しながら、深層強化学習 (RL) とクロスカントリー消去と呼ばれる概念を活用して、ヤコビアンの計算に必要な乗算の数を最適化する新しい方法を紹介します。
クロスカントリー消去は、計算グラフ上のすべての頂点の順序付けられた消去としてヤコビ累積を表現する自動微分のフレームワークであり、すべての消去には一定の計算コストがかかります。
必要な乗算の数を最小限に抑える最適な消去順序の探索を、RL エージェントによってプレイされるシングル プレーヤー ゲームとして定式化します。
この方法は、さまざまなドメインから取得したいくつかの関連タスクにおいて、最先端の方法と比較して最大 33% の改善を達成することを実証します。
さらに、取得した消去命令を効率的に実行できるクロスカントリー消去インタープリタを JAX に提供することで、これらの理論上の利点が実際の実行時間の改善につながることを示します。

要約(オリジナル)

Computing Jacobians with automatic differentiation is ubiquitous in many scientific domains such as machine learning, computational fluid dynamics, robotics and finance. Even small savings in the number of computations or memory usage in Jacobian computations can already incur massive savings in energy consumption and runtime. While there exist many methods that allow for such savings, they generally trade computational efficiency for approximations of the exact Jacobian. In this paper, we present a novel method to optimize the number of necessary multiplications for Jacobian computation by leveraging deep reinforcement learning (RL) and a concept called cross-country elimination while still computing the exact Jacobian. Cross-country elimination is a framework for automatic differentiation that phrases Jacobian accumulation as ordered elimination of all vertices on the computational graph where every elimination incurs a certain computational cost. We formulate the search for the optimal elimination order that minimizes the number of necessary multiplications as a single player game which is played by an RL agent. We demonstrate that this method achieves up to 33% improvements over state-of-the-art methods on several relevant tasks taken from diverse domains. Furthermore, we show that these theoretical gains translate into actual runtime improvements by providing a cross-country elimination interpreter in JAX that can efficiently execute the obtained elimination orders.

arxiv情報

著者 Jamie Lohoff,Emre Neftci
発行日 2024-06-07 15:44:33+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.AI, cs.LG パーマリンク