LoginSignup
6
7

More than 3 years have passed since last update.

自作のデータセットを Keras のチュートリアル用データっぽく読み込む

Last updated at Posted at 2019-06-17

やったこと

Kerasでチュートリアル用のMnistデータを呼びだす際には言わずもがな以下のコードを実行する。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

これを自作のデータでそれっぽく行う。

(X_train, y_train), (X_test, y_test) = load_data("自作のデータセット名")

ちなみに本記事は自作のデータセットをAugmentationを行い、train・testデータに分割を行なった以前の記事の続き。

以前の記事で行なったこと

自作の少数の画像データから
 ↓
Augmentationを行い、データ数のカサ増し
 ↓
Trainデータ, Testデータ, Trainラベル, Testラベルをlist形式で取得

NumPyのバイナリファイルの活用

以前の記事の続きなので以下の変数が取得されている前提。

training_images # list / Trainデータ
training_labels # list / Testデータ
test_images # list / Trainラベル
test_labels # list / Testラベル

NumPyのバイナリファイルについてはこちらの記事を参考にさせて頂きました。

NumPyでは配列ndarrayをNumPy独自のバイナリファイル(npy, npz)で保存することが可能。データ型dtypeや形状shapeなどの情報を保持したまま書き込み・読み込み(出力・入力)ができる。

要はNumPyのオブジェクトを編集不可の状態で保存してくれるのがnpz形式のファイル(という理解)

以前の記事ではdog, catのデータを用いたのでデータセット名は以下を用いる。

'cats_dogs'

list → np.array → npz形式 へ
ちなみにnp.savezを用いてnpz形式ファイルの作成。

np.savez('パス文字列', array)

今回はカレントディレクトリに作成。

np.savez('cats_dogs_training_data.npz', np.array(training_images))
np.savez('cats_dogs_training_labels.npz', np.array(training_labels))
np.savez('cats_dogs_test_data.npz', np.array(test_images))
np.savez('cats_dogs_test_labels.npz', np.array(test_labels))

load_data関数の定義

まずは上記で保存したデータの読み込み、呼び出し方法の確認

デフォルトでは上の例のように保存時に引数に指定した順番にarr_0, arr_1...という名前が自動的に付けられる。

ようなので、以下の通り。

# 読み込み
train = np.load('cats_dogs_training_data.npz')

# 呼び出し
train['arr_0']

では早速
load_data関数の作成。

def load_data(datasetname):

   """
   note : 呼び出したいデータセット名の train/test 両データ・両ラベルをnp.array形式で取得する関数
   ----------
   datasetname : str
   """

    file_train_data = np.load(datasetname + "_training_data.npz")
    train = file_train_data["arr_0"]

    file_train_label = np.load(datasetname + "_training_labels.npz")
    train_labels = file_train_label["arr_0"]

    file_test_data = np.load(datasetname + "_test_data.npz")
    test = file_test_data["arr_0"]

    file_test_label = np.load(datasetname + "_test_labels.npz")
    test_labels = file_test_label["arr_0"]

    return (train, train_labels), (test, test_labels)

load_data関数の呼び出し

(X_train, y_train), (X_test, y_test) = load_data("cats_dogs")
6
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
7