Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
8
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

[初心者向け] Python MNISTのサンプルコードを自作データセットで動かすためのload_data自作

はじめに

ディープラーニングやAIに興味を持ち始めてサンプルコードを動かそうとすると、MNISTと呼ばれるデータセットを利用したサンプルが多く登場します。
MNISTとは0から9のラベルで分類されている手書き文字のデータセットで
解像度28x28のグレースケール画像のラベル付けデータセットです。

サンプルコード自体は環境構築さえできれば実行できるのですが、
自分で作成したオリジナルのデータセットが使いたくなりMNISTのコードを見ると下記の一行でデータセットのロードが完結していることが多いです。
※プログラムによってはここから正規化などのクレンジング処理を行います。

(x_train, y_train), (x_test, y_test) = mnist.load_data()

初心者の方がここからいきなり自作のデータセットの作成は非常にハードルが高いと思われます。
そのため、この記事ではmnist.load_dataの代わりに自作のデータセットをmnist形式で作成するプログラムを実装します。

mnist.load_data()

MNISTの仕様は公式のドキュメントにも紹介されています。
https://keras.io/ja/datasets/

使い方は上記のサンプルと同じです。
x_train, y_trainが学習用のデータとラベルが格納されており、
x_test, y_testも同様に検証用データ一式が格納されています。

トレーニングデータについてはこちらの記事が非常にわかりやすい為、共有させていただきます。
機械学習 トレーニングデータの分割と学習・予測・検証

自作のload_data()

今回自作するload_data関数を扱うための準備は
・フォルダごとに画像を分けて保存
これだけです。
想定としては各フォルダごとにラベルを付けていきます。
例えばフォルダAはイヌの画像ファイルから構成されるなど。

importの一覧とソースコードを示します。

import.txt
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
import pandas as pd
import os, glob
my_load_data().py
def my_load_data(folder_str, size):
    print('load_dataset...')
    folders = folder_str.split('__')
    X = []
    Y = []
    for index, fol_name in enumerate(folders):
        files = glob.glob(fol_name + '/*.jpg')
        for file in files:
            image = Image.open(file)
            image = image.resize((size, size))
            image = image.convert('L')
            data = np.asarray(image)
            X.append(data)
            Y.append(index)
    X = np.array(X)
    Y = np.array(Y)
    oh_encoder = OneHotEncoder(categories='auto', sparse=False)
    onehot = oh_encoder.fit_transform(pd.DataFrame(Y))
    X_train, X_test, y_train, y_test = train_test_split(X, onehot, test_size=0.2)
    return X_train, X_test, y_train, y_test

仮引数folder_strには画像を分割したフォルダを指定します。
ラベルを付ける場合、複数フォルダが必要なので'__'で区切ってフォルダ名を指定してください。
サンプルコードは拡張子はjpgですが変更していただいて結構です。
sizeは解像度になります。MNISTは28x28ですので28を指定。
ラベルがonehotらしいので一応変換しておきます。
実際に上記関数を利用する際のmain関数です。

sample.py
import argparse

def main():
    parser = argparse.ArgumentParser(description='sample')
    parser.add_argument('--folder', '-i')
    parser.add_argument('--size', '-s', type=int, default=28)
    args = parser.parse_args()
    X_train, X_test, y_train, y_test = my_load_data(args.folder, args.size)

    # 確認
    print('X_train',X_train)
    print('y_train',y_train)

実行コマンドの一例
f1、f2、f3はカレントディレクトリにある画像が入っているフォルダを想定しています。

python sample.py --folder f1__f2__f3 -s 28

おわりに

今回はMNISTのload_dataを自作のデータで試せるようなmy_load_dataを作成しました。
MNISTのサンプルを動かして楽しんでいただければ幸いです。
動作の不具合や分からない内容などあればコメントお待ちしています。

この記事を作成するに当たり、様々な先人の知恵をお借りしました。
最後に記載させていただきます。
ご一読ありがとうございました。LGTMも良ければお願いします!

参考資料

画像データからnumpy形式に変換する方法
Keras VAEの画像異常検出を理解する
畳み込みオートエンコーダによる画像の再現、ノイズ除去、セグメンテーション

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
8
Help us understand the problem. What are the problem?