1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Tensorflow/Kerasで学習済み「モデル/重み」を保存/復元する方法

Last updated at Posted at 2021-02-07

モデル/重みの保存

機械学習を行っている際に毎回同じ学習を行うのは時間の無駄なので、学習済みモデル/重みとして保存しておきましょう。
保存した学習済みモデル/重みを復元(load)して認識などを行う方法も紹介します。
MNISTを例に紹介します。
公式日本語解説:https://www.tensorflow.org/tutorials/keras/save_and_load?hl=ja

保存の方法の種類

ファイルとして保存/フォルダー形式で保存の2つのパターンがあります。
ファイルは.h5という拡張子のHDF5ファイルで保存します。
フォルダー形式だとややこしいので、1つのファイル(HDF5)で保存する方法を紹介します。

重みだけ保存

モデルは保存しないで重みだけ保存する方法は別記事で紹介します。
こちらは復元の際は先にモデルを別で定義する必要があります。
Qiita: Tensorflow/Kerasで学習済み「重み」を保存/復元する方法

モデル/重みの保存

model.save()でモデルと重みを保存します。
これはmodel.fit()で学習した後に実行します。

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import RMSprop

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(x_train.shape[0], x_train.shape[1] * x_train.shape[2])
x_test  = x_test.reshape(x_test.shape[0],   x_test.shape[1]  * x_test.shape[2])

x_train = x_train.astype('float32')
x_test  = x_test.astype('float32')

x_train /= 255
x_test  /= 255

num_classes = len(list(set(y_train)))

y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

input_shape = x_train[0].shape

model = Sequential()
model.add(Dense(512, activation='relu', input_shape=input_shape))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])
model.summary()

batch_size = 16
epochs = 3

history  = model.fit(x_train, y_train, 
                     batch_size=batch_size,
                     epochs=epochs,
                     verbose=1,
                     validation_data=(x_test, y_test))

model.save('mnist_widgets.h5') #モデルと重みの保存

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

##モデル/重みの復元

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model


(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_test  = x_test.reshape(x_test.shape[0],   x_test.shape[1]  * x_test.shape[2])

x_test  = x_test.astype('float32')

x_test  /= 255

num_classes = len(list(set(y_test)))

y_test = to_categorical(y_test, num_classes)

model = load_model('mnist_widgets.h5') #モデルと重みを復元

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?