stMind

about Tech, Computer vision and Machine learning

JaxがGPUを認識してない場合の対応

github.com

GPUを認識していない場合には、以下のようなメッセージが出ます。

In [1]: import jax

In [2]: jax.devices()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Out[2]: [CpuDevice(id=0)]

インストールされているjaxとjaxlibはCPU-onlyになっている。

$ pip list | grep jax
jax                           0.4.34
jaxlib                        0.4.34
jaxopt                        0.8.3
jaxtyping                     0.2.33
ott-jax                       0.4.7

GPU対応をインストールする。

$ pip install "jax[cuda]"
$ pip list | grep jax
jax                           0.4.34
jax-cuda12-pjrt               0.4.34
jax-cuda12-plugin             0.4.34
jaxlib                        0.4.34
jaxopt                        0.8.3
jaxtyping                     0.2.33
ott-jax                       0.4.7

先程のメッセージが出なくなり、CudaDeviceが表示されます。

In [1]: import jax

In [2]: jax.devices()
Out[2]: [CudaDevice(id=0)]