2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlowで1次関数を予測(3) … ミニバッチ

Last updated at Posted at 2017-03-12

チュートリアルの1次関数の予測をネタとしていろいろ試してみます。

前回と同様、下記の直線を予測します。

y=0.5x+10

学習には、以下の種類があるらしいです。

訓練データひとつずつ使用・・・オンライン学習
訓練データすべてを使用・・・バッチ学習
訓練データから一部を選んで使用・・・ミニバッチ学習

チュートリアルでは訓練データすべてを使用していたので、バッチ学習でしょうか。
今回は、ミニバッチで学習してみます。

ソース

import tensorflow as tf
import numpy as np

np.random.seed(seed=32) # シード (種) を指定、発生する乱数をあらかじめ固定する(測定のため)

## 学習データ作成(x座標をランダムに作成(0.0〜1.0を100個)
x_data = np.random.rand(100).astype(np.float32)
# y座標を生成 (y = 0.5x + 10)
y_data = 0.5 * x_data + 10

## モデルを作成(y_data = W * x_data + b となる W と b の適正値を見つけます。
x_ = tf.placeholder(tf.float32, shape=[None, 1])
y_ = tf.placeholder(tf.float32, shape=[None, 1])
W = tf.Variable(tf.zeros([1, 1]))
b = tf.Variable(tf.zeros([1]))
y = tf.matmul(x_, W) + b

## 損失関数を作成(最小二乗法を使用
loss = tf.reduce_mean(tf.square(y - y_))

## 最適化アルゴリズムを指定(勾配降下法で損失関数を最小化
optimizer = tf.train.AdamOptimizer(1)
train = optimizer.minimize(loss)

## パラメータを初期化(Variableを使用する場合必要らしい
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

## 学習する
batch_size = 20 # バッチサイズ
for step in range(200):
    # ランダムなindexを作成(0〜学習データのサイズをバッチの個数分
    rnd_idx = np.random.randint(0, x_data.size, batch_size)
    x_batch_data = x_data[rnd_idx].reshape(batch_size, 1)
    y_batch_data = y_data[rnd_idx].reshape(batch_size, 1)
    sess.run(train, feed_dict={x_: x_batch_data, y_: y_batch_data})
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))

測定時の差分をなくすため最初にnp.random.seed(seed=32)
乱数の種を固定しています。

結果

バッチサイズ1(オンライン学習)

step W b
0 [ 1.] [ 0.99999988]
20 [ 7.06871223] [ 8.41462135]
40 [ 3.96863031] [ 8.00734711]
60 [ 1.58839536] [ 8.59520626]
80 [ 0.61793894] [ 9.73549461]
100 [ 0.40156782] [ 10.096591]
120 [ 0.43321908] [ 10.05368233]
140 [ 0.48158637] [ 10.00192833]
160 [ 0.49636325] [ 9.99834824]
180 [ 0.5013358] [ 9.99777508]

バッチサイズ5(ミニバッチ学習)

step W b
0 [ 1.] [ 0.99999988]
20 [ 5.80828667] [ 8.09599781]
40 [ 3.27099919] [ 9.08414268]
60 [ 1.33084488] [ 9.45936298]
80 [ 0.60776621] [ 9.88074303]
100 [ 0.43642431] [ 10.04355145]
120 [ 0.45859516] [ 10.02548981]
140 [ 0.49456775] [ 10.00613308]
160 [ 0.502379] [ 10.00013828]
180 [ 0.50175482] [ 9.9994936]

バッチサイズ20(ミニバッチ学習)

step W b
0 [ 1.] [ 1.]
20 [ 6.15904236] [ 8.29144859]
40 [ 3.32724142] [ 8.98768711]
60 [ 1.18161225] [ 9.33158493]
80 [ 0.53098202] [ 9.83389091]
100 [ 0.43710515] [ 10.02393532]
120 [ 0.46652189] [ 10.02241516]
140 [ 0.4972609] [ 10.00498962]
160 [ 0.50444597] [ 9.99849892]
180 [ 0.50116104] [ 9.99929047]

あまり変わりませんね。
このくらい簡単な予測だと効果ないのかもしれません。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?