
Keras with TensorFlow, Part 2

Posted at 2017-09-02

Overview

Tried Keras on TensorFlow.
Trained a small network on XOR.
Saved the model and loaded it back.

Environment

Windows 7 SP1 64-bit
Anaconda3
TensorFlow 1.2

Sample code

import os.path

import numpy as np
# Keras as bundled with TensorFlow 1.2 (under tensorflow.contrib.keras)
from tensorflow.contrib.keras.python.keras.models import Sequential, model_from_json
from tensorflow.contrib.keras.python.keras.layers.core import Dense, Activation
from tensorflow.contrib.keras.python.keras.optimizers import SGD

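# The four input pairs
X = np.array([[0, 0], [0, 1.0], [1.0, 0], [1.0, 1.0]])
# XOR targets: 1 only when exactly one input is 1
# (the original [1, 0, 0, 1] was XNOR, the inverse of XOR)
y = np.array([[0], [1.0], [1.0], [0]])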

# 2 inputs -> 8 tanh hidden units -> 1 sigmoid output
model = Sequential()
model.add(Dense(8, input_dim=2))
model.add(Activation('tanh'))
model.add(Dense(1))
model.add(Activation('sigmoid'))
sgd = SGD(lr=0.1)
model.compile(loss='binary_crossentropy', optimizer=sgd)
# batch_size=1 updates the weights after every sample
model.fit(X, y, batch_size=1, epochs=300)
model.summary()
print(model.predict_proba(X))

# Save the architecture (JSON and YAML) separately from the weights (HDF5)
print('save model')
json_string = model.to_json()
with open(os.path.join('./', 'xor_model.json'), 'w') as f:
    f.write(json_string)
yaml_string = model.to_yaml()
with open(os.path.join('./', 'xor_model.yaml'), 'w') as f:
    f.write(yaml_string)
print('save weights')
model.save_weights(os.path.join('./', 'xor_model.h5'))

# Rebuild the model from the JSON architecture, then load the trained weights
with open(os.path.join('./', 'xor_model.json')) as f:
    model1 = model_from_json(f.read())
model1.load_weights(os.path.join('./', 'xor_model.h5'))
model1.summary()
# The loaded model has to be compiled again before evaluate()
model1.compile(loss='binary_crossentropy', optimizer='sgd')
score = model1.evaluate(X, y, verbose=0)
print(score)
print(model1.predict_proba(X))
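
The sample code above saves the architecture and the weights as separate files and then reassembles them. As a rough illustration of that idea without Keras, here is a minimal NumPy sketch of the same 2-8-1 network shape (tanh hidden layer, sigmoid output) with untrained random weights. It is not the format Keras uses: the helper names `init_params` and `forward`, the file names `toy_arch.json` and `toy_weights.npz`, and the JSON layout are all hypothetical, chosen only for this example.

```python
import json
import os
import tempfile

import numpy as np


def init_params(rng, n_in=2, n_hidden=8, n_out=1):
    """Random weights whose shapes mirror Dense(8, input_dim=2) and Dense(1)."""
    return {
        "W1": rng.normal(scale=0.5, size=(n_in, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.normal(scale=0.5, size=(n_hidden, n_out)),
        "b2": np.zeros(n_out),
    }


def forward(params, X):
    """tanh hidden layer followed by a single sigmoid output unit."""
    hidden = np.tanh(X @ params["W1"] + params["b1"])
    return 1.0 / (1.0 + np.exp(-(hidden @ params["W2"] + params["b2"])))


X = np.array([[0, 0], [0, 1.0], [1.0, 0], [1.0, 1.0]])
params = init_params(np.random.default_rng(0))

# Trainable parameters: (2*8 + 8) + (8*1 + 1)
n_params = sum(p.size for p in params.values())

# Hypothetical file names, written to a temporary directory.
out_dir = tempfile.mkdtemp()
arch_path = os.path.join(out_dir, "toy_arch.json")
weights_path = os.path.join(out_dir, "toy_weights.npz")

# Architecture goes into a text file, weights into a binary array file.
arch = {"layers": [{"units": 8, "activation": "tanh"},
                   {"units": 1, "activation": "sigmoid"}]}
with open(arch_path, "w") as f:
    json.dump(arch, f)
np.savez(weights_path, **params)

# Reload both pieces and check that the predictions are unchanged.
with open(arch_path) as f:
    arch_loaded = json.load(f)
with np.load(weights_path) as data:
    params_loaded = {name: data[name] for name in data.files}

before = forward(params, X)
after = forward(params_loaded, X)
print(n_params, np.allclose(before, after))
```

The parameter count computed here, 24 for the hidden layer plus 9 for the output layer, should match the total that `model.summary()` reports for the Keras model, since the `Activation` layers have no weights of their own.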
