2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[tensorflow , keras , mnist] mnistのデータから各ラベルごとにn枚ずつ取り出し10*n枚のデータを作成する

Posted at

前置き

なぜわざわざデータ量を少なく取りたいのか

  • データ量が少ない場合にも良い精度が出るモデルを考えたいため
  • 各ラベルごとに偏りをなくしたいため

コード

github
mnistのロードを行う

(x_train, y_train), (x_test, y_test) = mnist.load_data()

y_trainをpandasのdataframeに一度格納し、そこからdataframeを分割、indexを取り出す。という流れでn=100で取り出します。

#各label100枚ずつ足りだすためのコード、pandasを用いて行う
df = pd.DataFrame(columns=["label"])
df["label"] = y_train.reshape([-1])

list_0 = df.loc[df.label==0].sample(n=100)#n=100でsampling
list_1 = df.loc[df.label==1].sample(n=100)
list_2 = df.loc[df.label==2].sample(n=100)
list_3 = df.loc[df.label==3].sample(n=100)
list_4 = df.loc[df.label==4].sample(n=100)
list_5 = df.loc[df.label==5].sample(n=100)
list_6 = df.loc[df.label==6].sample(n=100)
list_7 = df.loc[df.label==7].sample(n=100)
list_8 = df.loc[df.label==8].sample(n=100)
list_9 = df.loc[df.label==9].sample(n=100)

label_list = pd.concat([list_0,list_1,list_2,list_3,list_4,list_5,list_6,list_7,list_8,
                       list_9])
label_list = label_list.sort_index()
label_idx = label_list.index.values

train_label = label_list.label.values

"""
x_trainからlabel用のdataframe.indexを取り出すことでlabelに対応したデータを取り出す。
"""
x_train = x_train[label_idx]
y_train= train_label
x_train = x_train / 255
x_test = x_test / 255

これで各ラベル100枚ごとにサンプリングできました。

2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?