LoginSignup
34
53

More than 5 years have passed since last update.

[Keras/TensorFlow] Kerasでweightの保存と読み込み利用

Last updated at Posted at 2017-04-12

目的

ゼロからKerasとTensorFlow(TF)を自由自在に動かせるようになる。
そのための、End to Endの作業ログ(備忘録)を残す。
※環境はMacだが、他のOSでの汎用性を保つように意識。
※アジャイルで執筆しており、精度を逐次高めていく予定。

環境

  • Mac: 10.12.3
  • Python: 3.6
  • TensorFlow: 1.0.1
  • Keras: 2.0.2

To Do

Keras(Tensorflow)の環境構築
KerasでMINSTの学習と予測
KerasでTensorBoardの利用
Kerasで重みファイルの保存/読み込み<---いまココ
Kerasで自前データの学習と予測
Kerasで転移学習

流れ

重みファイルの保存と読み込みの方法を理解して、新しい入力データで予測精度を試す。

学習済み重みデータの保存

参考:https://keras.io/ja/models/about-keras-models/

モデルの保存、重み情報なし

json_string = model.to_json()
open('mnist_mlp_model.json', 'w').write(json_string)

重みの保存

model.save_weights('mnist_mlp_weights.h5')

mnist_mlp.pyの末尾にファイル保存を追加。epochs数は1に変更
https://github.com/fchollet/keras/blob/master/examples/mnist_mlp.py

mnist_mlp_weights.py
from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop


batch_size = 128
num_classes = 10
epochs = 1

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))

model.summary()

model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])

history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

### save weights
json_string = model.to_json()
open('mnist_mlp_model.json', 'w').write(json_string)
model.save_weights('mnist_mlp_weights.h5')

作成したファイルを実行する。

python mnist_mlp_weights.py

実行すると、mnist_mlp_model.jsonとmnist_mlp_weights.h5が作成される。

mnist_mlp_model.json
{"class_name": "Sequential", "config": [{"class_name": "Dense", "config": {"name": "dense_43", "trainable": true, "batch_input_shape": [null, 784], "dtype": "float32", "units": 512, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dropout", "config": {"name": "dropout_29", "trainable": true, "rate": 0.2}}, {"class_name": "Dense", "config": {"name": "dense_44", "trainable": true, "units": 512, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dropout", "config": {"name": "dropout_30", "trainable": true, "rate": 0.2}}, {"class_name": "Dense", "config": {"name": "dense_45", "trainable": true, "units": 10, "activation": "softmax", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}], "keras_version": "2.0.2", "backend": "tensorflow"}

学習済み重みデータの読み込み

パターン1:h5ファイルからの読み込み

以下で読み込み

model.load_weights('mnist_mlp_weights.h5')

パターン2:Kerasから自動ダウンロード

参考:https://keras.io/ja/applications/
KerasではweightsオプションでImageNetの学習済み重みの読み込みを簡単に選択できる。

from keras.applications.vgg16 import VGG16
model = VGG16(weights='imagenet', include_top=False)

以下のモデルについて、ImageNetのデータがある。

  • Xception
  • VGG16
  • VGG19
  • ResNet50
  • InceptionV3

h5の種類は以下のリンクから確認できる。topなしの場合数十MB、ありの場合は数百MBのファイルである。
https://github.com/fchollet/deep-learning-models/releases/

新しい画像で予測

先程はmlpのサンプルスクリプトを利用したが、ここではもう少し複雑なモデルであるvgg16を利用する。
Kerasの公式サイトのvgg16のサンプルを利用:https://keras.io/ja/applications/
結果表示で名前が表示されるために、include_top=Trueに変更する必要あり。
象の画像elephant.jpgはgoogle画像検索して適宜調達する。
結果表示の方法は以下を参考。
http://aidiary.hatenablog.com/entry/20170104/1483535144

実行すると、重みファイルを読み込むため少々時間がかかる。

vgg16_imagenet_prelearnes.py
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import numpy as np

model = VGG16(weights='imagenet', include_top=True)

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

features = model.predict(x)

### show result
from keras.applications.vgg16 import decode_predictions
results = decode_predictions(features, top=5)[0]
for result in results:
    print(result)

実行する。

python vgg16_imagenet_prelearnes.py

結果、象と予測されていることが確認できる。

('n02504013', 'Indian_elephant', 0.7175734)
('n02504458', 'African_elephant', 0.24314526)
('n01871265', 'tusker', 0.036461893)
('n02437312', 'Arabian_camel', 0.00091134955)
('n01704323', 'triceratops', 0.00051373575)

End

34
53
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
34
53