前置き
なぜわざわざデータ量を少なく取りたいのか
- データ量が少ない場合にも良い精度が出るモデルを考えたいため
- 各ラベルごとに偏りをなくしたいため
コード
github
mnistのロードを行う
(x_train, y_train), (x_test, y_test) = mnist.load_data()
y_trainをpandasのdataframeに一度格納し、そこからdataframeを分割、indexを取り出す。という流れでn=100で取り出します。
#各label100枚ずつ足りだすためのコード、pandasを用いて行う
df = pd.DataFrame(columns=["label"])
df["label"] = y_train.reshape([-1])
list_0 = df.loc[df.label==0].sample(n=100)#n=100でsampling
list_1 = df.loc[df.label==1].sample(n=100)
list_2 = df.loc[df.label==2].sample(n=100)
list_3 = df.loc[df.label==3].sample(n=100)
list_4 = df.loc[df.label==4].sample(n=100)
list_5 = df.loc[df.label==5].sample(n=100)
list_6 = df.loc[df.label==6].sample(n=100)
list_7 = df.loc[df.label==7].sample(n=100)
list_8 = df.loc[df.label==8].sample(n=100)
list_9 = df.loc[df.label==9].sample(n=100)
label_list = pd.concat([list_0,list_1,list_2,list_3,list_4,list_5,list_6,list_7,list_8,
list_9])
label_list = label_list.sort_index()
label_idx = label_list.index.values
train_label = label_list.label.values
"""
x_trainからlabel用のdataframe.indexを取り出すことでlabelに対応したデータを取り出す。
"""
x_train = x_train[label_idx]
y_train= train_label
x_train = x_train / 255
x_test = x_test / 255
これで各ラベル100枚ごとにサンプリングできました。