Python
numpy

NumPy でデータセット (ndarray) を任意の割合に分割する

More than 1 year has passed since last update.


方法

numpy.split() を使います。


(例 1) データセットを train : test = 7 : 3 に分割する

import numpy as np

ds = np.arange(128) # array([0, 1, 2, ..., 127])

train, test = np.split(ds, [int(ds.size * 0.7)])

train # array([[0, 1, ..., 88])
test # array([[89, 90, ..., 127])

train.size # 89 ≈ 128 * 0.7 = 89.6
test.size # 39 ≈ 128 * 0.3 = 38.4


(例 2) データセットを train : test : validation = 6 : 2 : 2 に分割する

import numpy as np

ds = np.arange(128) # array([0, 1, 2, ..., 127])

indices = [int(ds.size * n) for n in [0.6, 0.6 + 0.2]] # [76, 102]
train, test, validation = np.split(ds, indices)

train # array([0, 1, ..., 75])
test # array([76, 77, ..., 101])
validation # array([102, 103, ..., 127])

train.size # 76 ≈ 128 * 0.6 = 76.8
test.size # 26 ≈ 128 * 0.2 = 25.6
validation.size # 26 ≈ 128 * 0.2 = 25.6


参考