
More than 5 years have passed since last update.

Maximum likelihood estimation of mean and variance with TensorFlow


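The misconception that TensorFlow is a machine learning library is widespread. To deepen my understanding of what it actually does, I used it to estimate a mean and a variance.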

  • Generate 100 normally distributed random numbers with mean 50 and standard deviation 10.
  • Try several different learning rates.
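It helps to know in advance what values the optimizers should converge to. The following is a minimal sketch using only the Python standard library. It assumes a fixed seed, which the original does not use (it calls np.random without one). For a normal distribution the maximum likelihood estimates are the sample mean and the standard deviation computed with divisor N. In Python's `statistics` module that is `pstdev`, not `stdev`.

```python
import random
import statistics

# Stdlib stand-in for np.random.randn(100) * 10 + 50.
# The seed 0 is an assumption added for reproducibility.
rng = random.Random(0)
x_train = [rng.gauss(50.0, 10.0) for _ in range(100)]

# Closed-form maximum likelihood estimates for a normal distribution
mle_mean = statistics.fmean(x_train)   # sample mean
mle_std = statistics.pstdev(x_train)   # divisor N (MLE); statistics.stdev uses N - 1

print(f"MLE mean = {mle_mean:.3f}, MLE std = {mle_std:.3f}")
```

The gradient-based estimates in the plots should head toward these two numbers. Because of the 1/N divisor, the std estimate is slightly smaller than the usual unbiased sample standard deviation.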

If you only want the mean, see TensorFlowで数列の平均を求めてみた ("Computing the mean of a sequence with TensorFlow").
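For reference, the `loss` in the code below is the negative log-likelihood of N independent samples x_i from a normal distribution with mean m and standard deviation s (the same symbols as in the code). Setting both partial derivatives to zero gives the closed-form estimates that the optimizers should approach. This derivation is added here and is not part of the original post.

```latex
\begin{aligned}
L(m, s) &= \frac{N}{2}\log\left(2\pi s^2\right) + \frac{1}{2s^2}\sum_{i=1}^{N}(x_i - m)^2 \\
\frac{\partial L}{\partial m} &= -\frac{1}{s^2}\sum_{i=1}^{N}(x_i - m) = 0
  \;\Rightarrow\; \hat{m} = \frac{1}{N}\sum_{i=1}^{N} x_i \\
\frac{\partial L}{\partial s} &= \frac{N}{s} - \frac{1}{s^3}\sum_{i=1}^{N}(x_i - m)^2 = 0
  \;\Rightarrow\; \hat{s}^2 = \frac{1}{N}\sum_{i=1}^{N}(x_i - \hat{m})^2
\end{aligned}
```

Both gradients scale with N, and the gradient in m also scales with 1/s². With 100 samples and a starting value of s = 3, the raw gradients are therefore large.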

Gradient descent

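import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph API (tf.placeholder, tf.Session, tf.train.*)

# 100 samples from a normal distribution with mean 50 and standard deviation 10
x_train = np.random.randn(100) * 10 + 50

n_itr = 10000

# Parameters to estimate, starting from deliberately poor values.
# The second positional argument of tf.Variable is `trainable`, so pass dtype by keyword.
m = tf.Variable([30.0], dtype=tf.float32)  # mean
s = tf.Variable([3.0], dtype=tf.float32)   # standard deviation
x = tf.placeholder(tf.float32)
# Number of samples (tf.count_nonzero would undercount if a sample were exactly 0)
N = tf.cast(tf.size(x), tf.float32)

# Negative log-likelihood of a normal distribution
loss = N / 2 * tf.log(2 * np.pi * s**2) + 1.0 / (2 * s**2) * tf.reduce_sum(tf.square(x - m))

for lr in [0.1, 0.01, 0.001]:

    optimizer = tf.train.GradientDescentOptimizer(lr)
    train = optimizer.minimize(loss)

    # Re-initializes m and s so every learning rate starts from the same point
    init = tf.global_variables_initializer()

    est_m = []
    est_s = []
    with tf.Session() as sess:
        sess.run(init)
        for i in range(n_itr):
            sess.run(train, {x: x_train})
            # Record the current estimates after each update
            cur_m, cur_s = sess.run([m, s])
            est_m.append(cur_m)
            est_s.append(cur_s)

    est_m = np.array(est_m)
    est_s = np.array(est_s)
    # Plot every 100th step of the (std, mean) trajectory
    plt.plot(est_s.reshape(n_itr)[::100], est_m.reshape(n_itr)[::100], marker=".", label="lr={}".format(lr))

plt.title("batch gradient descent")
plt.xlabel("std")
plt.ylabel("mean")
plt.legend()
plt.show();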

(Figure: batch gradient descent, estimated mean vs. std for lr = 0.1, 0.01, 0.001)
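This minimal pure-Python sketch (no TensorFlow) shows the update that `tf.train.GradientDescentOptimizer` applies to this loss. It uses the hand-derived gradient of the same negative log-likelihood and applies the plain update θ ← θ − lr·∇L to both parameters. The helper names `nll_grad` and `gradient_descent` are hypothetical, as are the seed and the use of Python's `random` module in place of NumPy.

```python
import random

def nll_grad(m, s, xs):
    """Gradient (dL/dm, dL/ds) of L = N/2*log(2*pi*s^2) + sum((x - m)^2) / (2*s^2)."""
    n = len(xs)
    dm = -sum(x - m for x in xs) / s**2
    ds = n / s - sum((x - m) ** 2 for x in xs) / s**3
    return dm, ds

def gradient_descent(xs, m, s, lr, n_itr):
    """Batch gradient descent: both parameters step against the full-batch gradient."""
    path = []
    for _ in range(n_itr):
        dm, ds = nll_grad(m, s, xs)  # gradient at the current (m, s)
        m, s = m - lr * dm, s - lr * ds
        path.append((m, s))
    return path

# Hypothetical data: 100 draws from N(50, 10^2) with a fixed seed
rng = random.Random(0)
xs = [rng.gauss(50.0, 10.0) for _ in range(100)]

# Same starting point (m=30, s=3) and smallest learning rate as the article
path = gradient_descent(xs, 30.0, 3.0, lr=0.001, n_itr=1000)
print("final (mean, std):", path[-1])
```

The step in m is lr·N·(mean − m)/s², so the learning rate that works for plain gradient descent depends on both the sample size and the current value of s.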

RMSProp

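# Reuses x_train, x, m, s and loss defined in the gradient-descent cell above
n_itr = 1000

for lr in [1, 0.5, 0.1]:

    optimizer = tf.train.RMSPropOptimizer(lr)
    train = optimizer.minimize(loss)

    # Created after minimize() so it also initializes the optimizer's own variables
    init = tf.global_variables_initializer()

    est_m = []
    est_s = []
    with tf.Session() as sess:
        sess.run(init)
        for i in range(n_itr):
            sess.run(train, {x: x_train})
            # Record the current estimates after each update
            cur_m, cur_s = sess.run([m, s])
            est_m.append(cur_m)
            est_s.append(cur_s)

    est_m = np.array(est_m)
    est_s = np.array(est_s)
    # Plot every 10th step of the (std, mean) trajectory
    plt.plot(est_s.reshape(n_itr)[::10], est_m.reshape(n_itr)[::10], marker=".", label="lr={}".format(lr))

plt.title("batch RMSProp")
plt.xlabel("std")
plt.ylabel("mean")
plt.legend()
plt.show();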

(Figure: batch RMSProp, estimated mean vs. std for lr = 1, 0.5, 0.1)
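For comparison, this is a textbook RMSProp update without momentum. Each gradient is divided by a running root-mean-square of recent gradients, so the step size stays on the order of lr whatever the raw gradient magnitude. The sketch is pure Python and rests on several assumptions: decay rho=0.9, epsilon added outside the square root, and an accumulator that starts at zero. TF1's `RMSPropOptimizer` also offers momentum and a centered variant and may place epsilon or initialize its slots differently, so its trajectories will not match exactly. The names `nll_grad` and `rmsprop` and the seeded data are hypothetical.

```python
import random

def nll_grad(m, s, xs):
    """Gradient (dL/dm, dL/ds) of the normal negative log-likelihood used in the article."""
    n = len(xs)
    dm = -sum(x - m for x in xs) / s**2
    ds = n / s - sum((x - m) ** 2 for x in xs) / s**3
    return dm, ds

def rmsprop(xs, m, s, lr, n_itr, rho=0.9, eps=1e-8):
    """RMSProp without momentum, applied separately to m and s."""
    theta = [m, s]
    ms = [0.0, 0.0]  # running mean of squared gradients (assumed to start at 0)
    path = []
    for _ in range(n_itr):
        grads = nll_grad(theta[0], theta[1], xs)  # both gradients at the same point
        for i, g in enumerate(grads):
            ms[i] = rho * ms[i] + (1 - rho) * g * g
            theta[i] -= lr * g / (ms[i] ** 0.5 + eps)
        path.append(tuple(theta))
    return path

# Hypothetical data: 100 draws from N(50, 10^2) with a fixed seed
rng = random.Random(0)
xs = [rng.gauss(50.0, 10.0) for _ in range(100)]

path = rmsprop(xs, 30.0, 3.0, lr=0.1, n_itr=1000)
print("final (mean, std):", path[-1])
```

Since ms ≥ (1 − rho)·g², the ratio |g|/sqrt(ms) is at most 1/sqrt(1 − rho) ≈ 3.16, so no step exceeds about 3.16·lr. This helps explain why RMSProp tolerates larger learning rates here than plain gradient descent does.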

Adam

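# Reuses x_train, x, m, s and loss defined in the gradient-descent cell above
n_itr = 1000

for lr in [5, 1, 0.1, 0.01]:

    optimizer = tf.train.AdamOptimizer(lr)
    train = optimizer.minimize(loss)

    # Created after minimize() so it also initializes the optimizer's own variables
    init = tf.global_variables_initializer()

    est_m = []
    est_s = []
    with tf.Session() as sess:
        sess.run(init)
        for i in range(n_itr):
            sess.run(train, {x: x_train})
            # Record the current estimates after each update
            cur_m, cur_s = sess.run([m, s])
            est_m.append(cur_m)
            est_s.append(cur_s)

    est_m = np.array(est_m)
    est_s = np.array(est_s)
    # Plot every 10th step of the (std, mean) trajectory
    plt.plot(est_s.reshape(n_itr)[::10], est_m.reshape(n_itr)[::10], marker=".", label="lr={}".format(lr))

plt.title("batch Adam")
plt.xlabel("std")
plt.ylabel("mean")
plt.legend()
plt.show();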

(Figure: batch Adam, estimated mean vs. std for lr = 5, 1, 0.1, 0.01)
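A textbook Adam update (Kingma and Ba) keeps exponential moving averages of the gradient and of its square. It corrects their bias toward zero, then steps by lr·m̂/(sqrt(v̂) + ε). This pure-Python sketch uses beta1=0.9, beta2=0.999 and epsilon=1e-8, which are also the default arguments of `tf.train.AdamOptimizer`. However, TF1 folds the bias correction into the learning rate and handles epsilon slightly differently, so the numbers will not match exactly. The names `nll_grad` and `adam` and the seeded data are hypothetical.

```python
import random

def nll_grad(m, s, xs):
    """Gradient (dL/dm, dL/ds) of the normal negative log-likelihood used in the article."""
    n = len(xs)
    dm = -sum(x - m for x in xs) / s**2
    ds = n / s - sum((x - m) ** 2 for x in xs) / s**3
    return dm, ds

def adam(xs, m, s, lr, n_itr, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam with bias-corrected first and second moment estimates."""
    theta = [m, s]
    m1 = [0.0, 0.0]  # moving average of gradients
    m2 = [0.0, 0.0]  # moving average of squared gradients
    path = []
    for t in range(1, n_itr + 1):
        grads = nll_grad(theta[0], theta[1], xs)
        for i, g in enumerate(grads):
            m1[i] = beta1 * m1[i] + (1 - beta1) * g
            m2[i] = beta2 * m2[i] + (1 - beta2) * g * g
            m1_hat = m1[i] / (1 - beta1**t)  # bias correction
            m2_hat = m2[i] / (1 - beta2**t)
            theta[i] -= lr * m1_hat / (m2_hat ** 0.5 + eps)
        path.append(tuple(theta))
    return path

# Hypothetical data: 100 draws from N(50, 10^2) with a fixed seed
rng = random.Random(0)
xs = [rng.gauss(50.0, 10.0) for _ in range(100)]

path = adam(xs, 30.0, 3.0, lr=0.1, n_itr=1000)
print("final (mean, std):", path[-1])
```

On the first step the bias correction gives m1_hat = g and m2_hat = g², so each parameter moves by almost exactly lr. One way to read Adam's learning rate, then, is as a step length in parameter units rather than a multiplier on the gradient.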

AdaGrad

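# Reuses x_train, x, m, s and loss defined in the gradient-descent cell above
n_itr = 1000

for lr in [20, 10, 5, 1]:

    optimizer = tf.train.AdagradOptimizer(lr)
    train = optimizer.minimize(loss)

    # Created after minimize() so it also initializes the optimizer's own variables
    init = tf.global_variables_initializer()

    est_m = []
    est_s = []
    with tf.Session() as sess:
        sess.run(init)
        for i in range(n_itr):
            sess.run(train, {x: x_train})
            # Record the current estimates after each update
            cur_m, cur_s = sess.run([m, s])
            est_m.append(cur_m)
            est_s.append(cur_s)

    est_m = np.array(est_m)
    est_s = np.array(est_s)
    # Plot every 10th step of the (std, mean) trajectory
    plt.plot(est_s.reshape(n_itr)[::10], est_m.reshape(n_itr)[::10], marker=".", label="lr={}".format(lr))

plt.title("batch AdaGrad")
plt.xlabel("std")
plt.ylabel("mean")
plt.legend()
plt.show();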

(Figure: batch AdaGrad, estimated mean vs. std for lr = 20, 10, 5, 1)
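AdaGrad divides each step by the square root of the running sum of all past squared gradients, so its effective learning rate can only shrink. That is consistent with this post using comparatively large values (1 to 20). The pure-Python sketch below starts the accumulator at 0.1, matching the default `initial_accumulator_value` of `tf.train.AdagradOptimizer`. The names `nll_grad` and `adagrad` and the seeded data are hypothetical.

```python
import random

def nll_grad(m, s, xs):
    """Gradient (dL/dm, dL/ds) of the normal negative log-likelihood used in the article."""
    n = len(xs)
    dm = -sum(x - m for x in xs) / s**2
    ds = n / s - sum((x - m) ** 2 for x in xs) / s**3
    return dm, ds

def adagrad(xs, m, s, lr, n_itr, initial_accumulator=0.1):
    """AdaGrad: per-parameter steps scaled by 1/sqrt(sum of squared gradients)."""
    theta = [m, s]
    acc = [initial_accumulator, initial_accumulator]
    path = []
    for _ in range(n_itr):
        grads = nll_grad(theta[0], theta[1], xs)
        for i, g in enumerate(grads):
            acc[i] += g * g  # the accumulator never decreases
            theta[i] -= lr * g / acc[i] ** 0.5
        path.append(tuple(theta))
    return path

# Hypothetical data: 100 draws from N(50, 10^2) with a fixed seed
rng = random.Random(0)
xs = [rng.gauss(50.0, 10.0) for _ in range(100)]

path = adagrad(xs, 30.0, 3.0, lr=1.0, n_itr=1000)
print("final (mean, std):", path[-1])
```

Because acc already includes the current g², no single step can exceed lr. The first step is almost exactly lr, and later steps get smaller as the accumulator grows.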

AdaDelta

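# Reuses x_train, x, m, s and loss defined in the gradient-descent cell above
n_itr = 1000

for lr in [5000, 1000, 100]:

    optimizer = tf.train.AdadeltaOptimizer(lr)
    train = optimizer.minimize(loss)

    # Created after minimize() so it also initializes the optimizer's own variables
    init = tf.global_variables_initializer()

    est_m = []
    est_s = []
    with tf.Session() as sess:
        sess.run(init)
        for i in range(n_itr):
            sess.run(train, {x: x_train})
            # Record the current estimates after each update
            cur_m, cur_s = sess.run([m, s])
            est_m.append(cur_m)
            est_s.append(cur_s)

    est_m = np.array(est_m)
    est_s = np.array(est_s)
    # Plot every 10th step of the (std, mean) trajectory
    plt.plot(est_s.reshape(n_itr)[::10], est_m.reshape(n_itr)[::10], marker=".", label="lr={}".format(lr))

plt.title("batch AdaDelta")
plt.xlabel("std")
plt.ylabel("mean")
plt.legend()
plt.show();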

(Figure: batch AdaDelta, estimated mean vs. std for lr = 5000, 1000, 100)
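AdaDelta (Zeiler) replaces the global learning rate with a ratio of two running RMS values: one over past parameter updates and one over past gradients. The original algorithm has no learning rate. TF1's `tf.train.AdadeltaOptimizer` multiplies the resulting update by `learning_rate` (default 0.001, with rho=0.95 and epsilon=1e-8). Because the unscaled update starts out around sqrt(epsilon)/sqrt(1 − rho) ≈ 4.5e-4, this may be why learning rates as large as 100 to 5000 were needed to make visible progress within 1000 steps. The pure-Python sketch below follows that lr-scaled formulation. The names `nll_grad` and `adadelta` and the seeded data are hypothetical.

```python
import random

def nll_grad(m, s, xs):
    """Gradient (dL/dm, dL/ds) of the normal negative log-likelihood used in the article."""
    n = len(xs)
    dm = -sum(x - m for x in xs) / s**2
    ds = n / s - sum((x - m) ** 2 for x in xs) / s**3
    return dm, ds

def adadelta(xs, m, s, lr, n_itr, rho=0.95, eps=1e-8):
    """AdaDelta update, scaled by lr in the style of TF1's AdadeltaOptimizer."""
    theta = [m, s]
    eg2 = [0.0, 0.0]   # running average of squared gradients
    edx2 = [0.0, 0.0]  # running average of squared (unscaled) updates
    path = []
    for _ in range(n_itr):
        grads = nll_grad(theta[0], theta[1], xs)
        for i, g in enumerate(grads):
            eg2[i] = rho * eg2[i] + (1 - rho) * g * g
            update = (edx2[i] + eps) ** 0.5 / (eg2[i] + eps) ** 0.5 * g
            edx2[i] = rho * edx2[i] + (1 - rho) * update * update
            theta[i] -= lr * update
        path.append(tuple(theta))
    return path

# Hypothetical data: 100 draws from N(50, 10^2) with a fixed seed
rng = random.Random(0)
xs = [rng.gauss(50.0, 10.0) for _ in range(100)]

path = adadelta(xs, 30.0, 3.0, lr=100.0, n_itr=1000)
print("final (mean, std):", path[-1])
```

With lr=100 the first step is only about 0.045 in each parameter. The steps then grow slowly as the running average of past updates builds up.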
