stMind

about Tech, Computer vision and Machine learning

JAXとcomposable program transformations

https://github.com/google/jaxのAboutは、次のように記述されています。

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Composable transformationsはどういうことなのか? NeurIPS2020: JAX Ecosystem Meetupの動画で、DeepMindのエンジニアの方が解説されていました。

NeurIPS 2020: JAX Ecosystem Meetup - YouTube

例として、次の関数を考えます。

def fn(x, y):
  return x**2 + y

fn(1., 2.) # (1**2 + 2) = 3

これに対して、gradientはどう書けるか?

df_dx = grad(fn)
df_dx(1., 2.) # df_dx = 2*x = 2*1 = 2

ここで、gradは関数を返す関数で、df_dxも関数になる。そして、通常の関数呼び出しで使用することができる。

さらに、second order gradientはどう書けるか?

df2_dx = grad(grad(fn))
df2_dx(1., 2.) # df2_dx = d(2*x)_dx = 2

gradはcomposableなため、gradをもう一つ追加するだけで良い。

composableなのはgradだけでなく、他の変換も使用することができる。compiled second-order gradientは以下のように実行できる。

df2_dx = jit(grad(grad(fn)))
df2_dx(1., 2.) # 2, ここでcompileされる
df2_dx(1., 2.) # 2, XLA pre compileのコードを実行、一回目よりも早い実行ができる

さらに、バッチ計算もcomposableに付け加えることができる。(batched compiled second-order gradient)

df2_dx = vmap(jit(grad(grad(fn))))
xs = jnp.ones((batch_size,))
df2_dx(xs, 2 * xs) # [2, 2] if batch_size=2

複数のアクセラレータ(GPUなど)で実行する場合も、composableに付け加えることができる。(multi-gpu batched compiled second-order gradient)

df2_dx = pmap(vmap(jit(grad(grad(fn)))))
xs = jnp.ones((num_gpus, batch_size,))
df2_dx(xs, 2 * xs) # [[2, 2], [2, 2]] if batch_size=2 and num_gpus=2

まとめ

以上が約5分のプレゼンで解説されていた内容ですが、分かりやすくて、 変換の組み合わせってそういうことか!と感動しました。 また動画には、HaikuやOptaxといったEcosystemの話や、他にもGANsなど様々なJAX実装の例があり、勉強になりました。 前回、JAX/FlaxでViTを実装してみましたが、今年はJAXをもっと使っていこうと思います。