stMind

about Tech, Computer vision and Machine learning

Lambda Labs GPU CloudでJAX/Flax

MacのMetalを使って、手持ちのM1 MacにもJAX/Flaxの実行環境を作ることは出来るのですが、 実際に学習をしようとしてもエラーで詰まってしまうことが多く、JAX/Flaxを実行できる環境を探していました。

Colabを使っても良いのですが、学習を実行するだけでなくて、JAXのビルド自体も試してみたいと思ったので、Lambda Labs GPU Cloudで実行してみることにしました。

実行したのは、FlaxのチュートリアルにあるMNISTの画像分類モデル(CNN)の学習です。

CPU実行するだけであれば何もする必要はなかったのですが、GPUを使う場合には少しだけ苦労しました。

GPU実行時のエラーとTF_FORCE_GPU_ALLOW_GROWTH

最初にGPUで実行したとき、次のようなエラーが出ました。

...
2023-09-18 05:06:37.640924: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
...
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

今回のチュートリアルでは、tensorflow_datasetsを使ってMNISTのデータセットをロードしていたけれど、tensorflowは何も設定しないと、初期化時にGPUメモリの大部分を割り当ててしまうので、TF_FORCE_GPU_ALLOW_GROWTH=trueとしておかないとエラーになる可能性がある様子。

stackoverflow.com

github.com

github.com

最後のIssueでは、PyTorchと一緒に使った場合にCUDNNのエラーになったようで、こちらはXLA_PYTHON_CLIENT_MEM_FRACTION=.88としてJAXのGPU割り当てを制限する方法で解決していた。

MNISTの学習実行

export TF_FORCE_GPU_ALLOW_GROWTH=trueとして実行すると、Lambda Labs GPU CloudでJAX/Flaxで書いたCNNモデルが学習できました。

...
2023-09-18 05:25:48.968484: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
...
train epoch: 10, loss: 0.007537598721683025, accuracy: 99.84666442871094
test epoch: 10, loss: 0.032926980406045914, accuracy: 99.0184326171875