チュートリアルの1次関数の予測をネタとしていろいろ試してみます。
前回と同様、下記の直線を予測します。
y=0.5x+10
学習には、以下の種類があるらしいです。
訓練データひとつずつ使用・・・オンライン学習
訓練データすべてを使用・・・バッチ学習
訓練データから一部を選んで使用・・・ミニバッチ学習
チュートリアルでは訓練データすべてを使用していたので、バッチ学習でしょうか。
今回は、ミニバッチで学習してみます。
ソース
import tensorflow as tf
import numpy as np
np.random.seed(seed=32) # シード (種) を指定、発生する乱数をあらかじめ固定する(測定のため)
## 学習データ作成(x座標をランダムに作成(0.0〜1.0を100個)
x_data = np.random.rand(100).astype(np.float32)
# y座標を生成 (y = 0.5x + 10)
y_data = 0.5 * x_data + 10
## モデルを作成(y_data = W * x_data + b となる W と b の適正値を見つけます。
x_ = tf.placeholder(tf.float32, shape=[None, 1])
y_ = tf.placeholder(tf.float32, shape=[None, 1])
W = tf.Variable(tf.zeros([1, 1]))
b = tf.Variable(tf.zeros([1]))
y = tf.matmul(x_, W) + b
## 損失関数を作成(最小二乗法を使用
loss = tf.reduce_mean(tf.square(y - y_))
## 最適化アルゴリズムを指定(勾配降下法で損失関数を最小化
optimizer = tf.train.AdamOptimizer(1)
train = optimizer.minimize(loss)
## パラメータを初期化(Variableを使用する場合必要らしい
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
## 学習する
batch_size = 20 # バッチサイズ
for step in range(200):
# ランダムなindexを作成(0〜学習データのサイズをバッチの個数分
rnd_idx = np.random.randint(0, x_data.size, batch_size)
x_batch_data = x_data[rnd_idx].reshape(batch_size, 1)
y_batch_data = y_data[rnd_idx].reshape(batch_size, 1)
sess.run(train, feed_dict={x_: x_batch_data, y_: y_batch_data})
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
測定時の差分をなくすため最初にnp.random.seed(seed=32)
で
乱数の種を固定しています。
結果
バッチサイズ1(オンライン学習)
step W b
0 [ 1.] [ 0.99999988]
20 [ 7.06871223] [ 8.41462135]
40 [ 3.96863031] [ 8.00734711]
60 [ 1.58839536] [ 8.59520626]
80 [ 0.61793894] [ 9.73549461]
100 [ 0.40156782] [ 10.096591]
120 [ 0.43321908] [ 10.05368233]
140 [ 0.48158637] [ 10.00192833]
160 [ 0.49636325] [ 9.99834824]
180 [ 0.5013358] [ 9.99777508]
バッチサイズ5(ミニバッチ学習)
step W b
0 [ 1.] [ 0.99999988]
20 [ 5.80828667] [ 8.09599781]
40 [ 3.27099919] [ 9.08414268]
60 [ 1.33084488] [ 9.45936298]
80 [ 0.60776621] [ 9.88074303]
100 [ 0.43642431] [ 10.04355145]
120 [ 0.45859516] [ 10.02548981]
140 [ 0.49456775] [ 10.00613308]
160 [ 0.502379] [ 10.00013828]
180 [ 0.50175482] [ 9.9994936]
バッチサイズ20(ミニバッチ学習)
step W b
0 [ 1.] [ 1.]
20 [ 6.15904236] [ 8.29144859]
40 [ 3.32724142] [ 8.98768711]
60 [ 1.18161225] [ 9.33158493]
80 [ 0.53098202] [ 9.83389091]
100 [ 0.43710515] [ 10.02393532]
120 [ 0.46652189] [ 10.02241516]
140 [ 0.4972609] [ 10.00498962]
160 [ 0.50444597] [ 9.99849892]
180 [ 0.50116104] [ 9.99929047]
あまり変わりませんね。
このくらい簡単な予測だと効果ないのかもしれません。