Automatic Functional Differentiation in JAX

要約

JAX を拡張して、高階関数 (関数と演算子) を自動的に区別する機能を追加します。
関数を配列の一般化として表すことにより、JAX の既存の基本システムをシームレスに使用して高階関数を実装します。
いくつかの主要なタイプの関数を構築するための基礎的な構成要素として機能する一連のプリミティブ演算子を紹介します。
導入されたすべてのプリミティブ演算子について、順方向および逆方向モードの自動微分のための JAX の内部プロトコルに合わせて、線形化と転置の両方のルールを導出して実装します。
この機能強化により、従来関数に使用されてきた同じ構文で関数を区別できるようになります。
結果として得られる関数勾配自体は、Python ですぐに呼び出すことができる関数です。
関数派生関数が不可欠なアプリケーションを通じて、このツールの有効性とシンプルさを紹介します。
この作品のソースコードは https://github.com/sail-sg/autofd で公開されています。

要約(オリジナル)

We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators). By representing functions as a generalization of arrays, we seamlessly use JAX’s existing primitive system to implement higher-order functions. We present a set of primitive operators that serve as foundational building blocks for constructing several key types of functionals. For every introduced primitive operator, we derive and implement both linearization and transposition rules, aligning with JAX’s internal protocols for forward and reverse mode automatic differentiation. This enhancement allows for functional differentiation in the same syntax traditionally use for functions. The resulting functional gradients are themselves functions ready to be invoked in python. We showcase this tool’s efficacy and simplicity through applications where functional derivatives are indispensable. The source code of this work is released at https://github.com/sail-sg/autofd .

arxiv情報

著者 Min Lin
発行日 2023-11-30 17:23:40+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.LG, cs.PL パーマリンク