1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Tensorflow/Kerasで学習済み「重み」を保存/復元する方法

Last updated at Posted at 2021-02-07

重みの保存

機械学習を行っている際に毎回同じ学習を行うのは時間の無駄なので、学習済み重みとして保存しておきましょう。
保存した学習済み重みを復元(load)して認識などを行う方法も紹介します。
MNISTを例に紹介します。
公式日本語解説:https://www.tensorflow.org/tutorials/keras/save_and_load?hl=ja

保存の方法の種類

重みはいくつかのファイルが指定のフォルダーに保存されます。
以下のようなファイルが保存されます。

  • checkpoint
  • checkpoints.data-00000-of-00001
  • checkpoints.index

モデルと重みを保存

重みだけでなくモデルも一緒に保存する方法は別の記事に載せています。
そちらではモデルも復元してくれるので、一から定義する必要はないです。
Qiita: Tensorflow/Kerasで学習済み「モデル/重み」を保存/復元する方法

重みの保存

model.save()でモデルと重みを保存します。
これはmodel.fit()で学習した後に実行します。

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import RMSprop

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(
    x_train.shape[0], x_train.shape[1] * x_train.shape[2])
x_test = x_test.reshape(x_test.shape[0],   x_test.shape[1] * x_test.shape[2])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

x_train /= 255
x_test /= 255

num_classes = len(list(set(y_train)))

y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

input_shape = x_train[0].shape

model = Sequential()
model.add(Dense(512, activation='relu', input_shape=input_shape))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer=RMSprop(), metrics=['accuracy'])
model.summary()

batch_size = 16
epochs = 3

history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_test, y_test))

model.save_weights('./checkpoints/checkpoint')

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

##重みの復元

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import RMSprop

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(
    x_train.shape[0], x_train.shape[1] * x_train.shape[2])
x_test = x_test.reshape(x_test.shape[0],   x_test.shape[1] * x_test.shape[2])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

x_train /= 255
x_test /= 255

num_classes = len(list(set(y_train)))

y_test = to_categorical(y_test, num_classes)

input_shape = x_test[0].shape

model = Sequential()
model.add(Dense(512, activation='relu', input_shape=input_shape))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])
model.summary()
model.load_weights('./checkpoints/checkpoint')


score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?