5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Kerasからパラメータ抽出についてのメモ

Posted at

WeightとBiasを見るために

Kerasでモデルをsaveしたときにハマったので、一応の自分用のメモ
もし抽出で悩んだ人がいれば幸い
例として以下のMNISTの3層MLP(下はFunctional API)

# coding utf-8
from __future__ import function
import keras
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Dropout, Input
from keras.optimizers import RMSprop

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784) 
x_test = x_test.reshape(10000, 784)   
x_train = x_train.astype('float32')   
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

inputs = Input(shape=(784,))
hidden = Dense(128, activation='relu')(inputs)
predictions = Dense(10, activation='softmax')(hidden)
model = Model(inputs=inputs, outputs=predictions)
model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])
model.save('mlp.h5')
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1)
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

これを実行後にmlp.h5を開いてパラメータを見てみよう
ここでいうパラメータはWeightとBiasの2つ
隠れ層(Dense)を見るために

# coding utf-8
from __future__ import print_function
from keras.models import load_model
import numpy as np

print('Loading model and parameters...')
model = load_model('mlp.h5')
model.summary()
print('Print the parameters...')

layer = model.layers[1]     #summaryよりInput->[0], Dense->[1]なのでmodel.layers[1]
w = layer.get_weights()[0]
w = np.array(w)
b = layer.get_weights()[1]
b = np.array(b)
print('**Parameters shape**')
print('w.shape', w.shape)
print('b.shape', b.shape)

print('w = ', w)
print('b = ', b)

結果は、

Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 784)               0
_________________________________________________________________
dense_1 (Dense)              (None, 128)               100480
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1290
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
Print the parameters...
**Parameters shape**
w.shape (784, 128)
b.shape (128,)
w =  [[ 0.02914305  0.03145985  0.01455035 ...  0.06927236 -0.078566
  -0.07896576]
 [-0.0025496  -0.06011734 -0.04034768 ...  0.05541005 -0.06247988
   0.02232351]
 [-0.02663267  0.06812485  0.04543307 ... -0.06999144 -0.01722335
   0.00585063]
 ...
 [-0.00419274 -0.00560436  0.03638219 ...  0.00250916  0.06558482
   0.05288389]
 [-0.05047298  0.04109138  0.03568707 ...  0.01852474 -0.04672909
   0.063163  ]
 [-0.03907039  0.05963113 -0.07631446 ...  0.07864786  0.05980506
  -0.01691122]]
b =  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]

Biasが息してません...
#単純な話
*model.save( )*の位置が悪かったのが原因
*model.compile( )*は学習の設定、*model.fit( )*は学習の開始であるので、学習終わった後のデータをsaveするのが常
何も考えずに挿入した箇所が大幅に悪かったのね...
なので、

model.compile(...)
model.fit(...)
model.save(...)

上の順番に列記すればちゃんと?した数字が出ました(検証はこれから)
当たり前のことだけど、一応の自分用のメモとして書いておきます
因みに上のWeightは息しているようで、重みの初期値を割り振られただけの数値だったというオチですな
因みにの因みに、compileの前にsaveするとまた違う警告が来ちゃいます(対処法->コチラ)
皆様も私みたいなペーペーなことで泣かないように気を付けましょう...あうあう

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?