GPUを認識していない場合には、以下のようなメッセージが出ます。
In [1]: import jax In [2]: jax.devices() An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu. Out[2]: [CpuDevice(id=0)]
インストールされているjaxとjaxlibはCPU-onlyになっている。
$ pip list | grep jax jax 0.4.34 jaxlib 0.4.34 jaxopt 0.8.3 jaxtyping 0.2.33 ott-jax 0.4.7
GPU対応をインストールする。
$ pip install "jax[cuda]" $ pip list | grep jax jax 0.4.34 jax-cuda12-pjrt 0.4.34 jax-cuda12-plugin 0.4.34 jaxlib 0.4.34 jaxopt 0.8.3 jaxtyping 0.2.33 ott-jax 0.4.7
先程のメッセージが出なくなり、CudaDeviceが表示されます。
In [1]: import jax In [2]: jax.devices() Out[2]: [CudaDevice(id=0)]