MacのMetalを使って、手持ちのM1 MacにもJAX/Flaxの実行環境を作ることは出来るのですが、 実際に学習をしようとしてもエラーで詰まってしまうことが多く、JAX/Flaxを実行できる環境を探していました。
Colabを使っても良いのですが、学習を実行するだけでなくて、JAXのビルド自体も試してみたいと思ったので、Lambda Labs GPU Cloudで実行してみることにしました。
実行したのは、FlaxのチュートリアルにあるMNISTの画像分類モデル(CNN)の学習です。
CPU実行するだけであれば何もする必要はなかったのですが、GPUを使う場合には少しだけ苦労しました。
GPU実行時のエラーとTF_FORCE_GPU_ALLOW_GROWTH
最初にGPUで実行したとき、次のようなエラーが出ました。
... 2023-09-18 05:06:37.640924: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR ... jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
今回のチュートリアルでは、tensorflow_datasetsを使ってMNISTのデータセットをロードしていたけれど、tensorflowは何も設定しないと、初期化時にGPUメモリの大部分を割り当ててしまうので、TF_FORCE_GPU_ALLOW_GROWTH=trueとしておかないとエラーになる可能性がある様子。
最後のIssueでは、PyTorchと一緒に使った場合にCUDNNのエラーになったようで、こちらはXLA_PYTHON_CLIENT_MEM_FRACTION=.88
としてJAXのGPU割り当てを制限する方法で解決していた。
MNISTの学習実行
export TF_FORCE_GPU_ALLOW_GROWTH=true
として実行すると、Lambda Labs GPU CloudでJAX/Flaxで書いたCNNモデルが学習できました。
... 2023-09-18 05:25:48.968484: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0. ... train epoch: 10, loss: 0.007537598721683025, accuracy: 99.84666442871094 test epoch: 10, loss: 0.032926980406045914, accuracy: 99.0184326171875