stMind

about Tech, Computer vision and Machine learning

データセットをtrain/val/testに分割するコードをnumpyで簡潔に記述する

tl;dr

numpy.splitを使って、aryを3つのsubarrayに分割する。

import numpy as np
train, val, test = 
    np.split(ary, [int(len(ary) * .6), int(len(ary) * .8)])

簡単な説明

データをtrain/testに分割する時、scikit-learnのtrain_test_splitを使うことが多いと思いますが、 train/val/testと分割しようとすると、一度train/testと分けた後でtestに対して再度train_test_splitするなどが必要です。

numpy.split(ary, [a, b])は、第一引数に指定されたaryに対してary[:a], ary[a:b], ary[b:]と分割されるため、一回の処理でデータセットをtrain/val/testに分割することができます。

train/val/testで60/20/20に分割するときは、

np.split(ary, [int(len(ary) * .6), int(len(ary) * .8)])

と指定すればOKです。

ただし、np.splitは単純に分割するだけですので、各セットのラベルの分布を維持するtrain_test_splitのstratifyのような分割をしたい場合には、別途ケアする必要があります。

参照先

以上の内容は、下記stackoverflowの投稿を参照したものです。

stackoverflow.com