stMind

about Tech, Computer vision and Machine learning

Llama3.1を推論で使うために必要なGPUメモリ

huggingface.co

Llama3.1モデルについてのHuggingfaceのブログの紹介です。 このブログの中で、推論時に必要なGPUメモリについて記載があるので、 実際に計算して確かめてみます。

推論時のGPUメモリ計算

developer.nvidia.com

計算式については、以下の通りです。

  • LLM推論時のGPUメモリ要件に大きく影響するのは、モデルの重みとKVキャッシュ
  • モデルの重みは、パラメータ数 * precision in bytes
    • 例えばLlama2-7Bを16 bit precisionでロードすると、7B * sizeof(FP16) で14GB
  • KVキャッシュの計算
    • 以下の式はトークン1つあたりのKVキャッシュのサイズで、最初の2はK行列とV行列を表す
    • これは、バッチ数分で各バッチの入力シーケンスの各トークンに対して必要となる(今回計算したHFの表においてはバッチ数は1)
      • 二つ目の式で、sequence_length以降がトークン1つあたりのKVキャッシュサイズの計算。ここでは、num_heads * dim_headをhidden_sizeとしていて、precision in bytesもsizeof(FP16)と固定した記載になっているが、FP16に限らない

Llama3.1の推論時のメモリ要件

Model Size FP16 FP8 INT4
8B 16 GB 8 GB 4 GB
70B 140 GB 70 GB 35 GB
405B 810 GB 405 GB 203 GB

Huggingfaceから抜粋しました。 モデルサイズは単純に、パラメータ数 * precision in bytesなので、8Bの場合は8B * sizeof(FP16) = 16GB、8B * sizeof(FP8) = 8GB、8B * sizeof(INT4) = 4GBとなっています。

FP16の場合のKVキャッシュについても、Huggingfaceから抜粋します。

Model Size 1k tokens 16k tokens 128k tokens
8B 0.125 GB 1.95 GB 15.62 GB
70B 0.313 GB 4.88 GB 39.06 GB
405B 0.984 GB 15.38 GB 123.05 GB

Llama3.1の論文でハイパーパラメータを参照し、計算してみます。

8Bモデルで16k tokensと128k tokensの場合、num_layersは32、num_headsは8、dim_headは(4096/32) = 128として、

16000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 1.953125
128000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 15.625

ちなみに、1k tokensを計算してみると、

1000 * (2 * 32 * 8 * 128 * 2) / 1024**3
# 0.1220703125

となって表の値と微妙に異なります... シーケンス長を1024として計算すると、0.125GBとなります。

70Bモデルも、num_layersを80、num_headsを8、dim_headを(8192/64)=128として計算すると、表と同じになることが分かります。 ただ、405Bモデルで、num_layersを126、num_headsを8、dim_headを(16384/128)=128として計算すると、表と同じ値になりません。

16000 * (2 * 126 * 8 * (16384/128) * 2) / 1024**3
# 7.6904296875
128000 * (2 * 126 * 8 * (16384/128) * 2) / 1024**3
# 61.5234375

GitHubのIssueで質問してみたところ、コメントしてくれた方がいました。

github.com

コメントの中にあったRedditのリンクを見てみると、405Bのモデルでは、num_headsが16から8に変更されたようで、確かに履歴にヘッド数を変更したコミットの8-kv-headsが含まれていました。num_headsを16として計算すると、表の値と一致しました。

16000 * (2 * 126 * 16 * (16384/128) * 2) / 1024**3
# 15.380859375
128000 * (2 * 126 * 16 * (16384/128) * 2) / 1024**3
# 123.046875

まとめ

GPUメモリ計算方法はLlama3.1に限らず使えるものなので、利用したいモデルで計算してみて、必要なHWリソースや適用する精度(FP8など)を把握すると良さそうです。