search
LoginSignup
6

More than 3 years have passed since last update.

posted at

updated at

多項式曲線フィッティングをTensorFlowで実装する[PRML]

はじめに

単純な回帰問題を解くための手法のひとつである多項式曲線フィッティングをTensorFlowを用いて実装しました。最終的なスクリプトはGithubで公開しているのそちらを参照してください。

環境

  • OS : Windows10 64bit
  • IDE : JetBrains Pycharm x64
  • Python 3.6
  • tensorflow 1.9.0

問題設定

C.M.Bishopの「パターン認識と機械学習 上」を元に問題設定を行います。
関数$\sin(2\pi x)$にランダムなノイズを加えることで人工的な訓練集合を生成していきます。今回はこの訓練集合を利用して新たな入力に対する目的変数を予測していく回帰問題を多項式曲線フィッティングによって解いていきます。
色々書きましたが、三角関数を多項式によって近似するという問題となります。

観測点と目標データ集合の作成

$N$個の観測点$x$を並べた$\boldsymbol{x}\equiv(x_1, \cdots, x_N)^T$とそれぞれの観測点に対応する観測値$t$を並べた$\boldsymbol{t}\equiv(t_1, \cdots, t_N)^T$を与えます。そして$\sin(2\pi x)$の関数値を計算した後に、ガウス分布に従う小さなランダムノイズを加えることで目標データ集合$\boldsymbol{t}$を生成します。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# 観測点(N=10)の作成
N = 10
x_vals = np.reshape(np.linspace(0., 1., N), newshape=[N, 1])
c_vals = np.reshape(np.linspace(0., 1., 100), newshape=[100, 1])    # 検証用の連続点
x = tf.placeholder(dtype=tf.float32, shape=[None, 1])

# 目標データ集合の作成
sin = tf.sin(2.*np.pi*x)
noise = tf.random_normal([N, 1], stddev=0.1)
t = sin + noise

# 訓練データ集合の表示
with tf.Session() as sess:
    plt.plot(c_vals, sess.run(sin, feed_dict={x: c_vals}), color='green', label='sin(2πx)')
    plt.scatter(x_vals, sess.run(t, feed_dict={x: x_vals}), label='target')
    plt.title('Set of training data (N=10)')
    plt.xlabel('x')
    plt.ylabel('t')
    plt.legend()
    plt.show()

多項式の設定

今回は以下のような多項式を使ってデータへのフィッティングを行っていきます。$\boldsymbol{w}\equiv(w_1, \cdots, w_M)$は$M$次多項式における各項の係数です。

$$y(x, \boldsymbol{w})=w_0 + w_1x + w_2x^2 + \cdots + w_Mx^M=\sum_{j=0}^Mw_jx^j$$
Summationを消すためにベクトルの内積の形にします。

$$y(\boldsymbol{x},\boldsymbol{w})=〈\boldsymbol{x},\boldsymbol{w}〉$$

今回はPRMLに沿って3次多項式を用いています。色々検証したい方は$M$の値を変えて下さい。


# 多項式の設定
M = 3
w = tf.Variable(tf.random_normal([M+1, 1]))
xj = tf.pow(x, np.arange(M+1))
y = tf.matmul(xj, w)

損失関数と最適化アルゴリズムの設定

損失関数は$\boldsymbol{w}$を任意に固定した時の関数$y(x, \boldsymbol{w})$の値と目標データ点$t_n$との間の誤差を示す関数です。今回は単純で広く用いられている二乗和誤差(Sum-of-squares Error)を用いることにします。二乗和誤差は以下の式で書けます。

$$E(\boldsymbol{w})=\frac{1}{2}\sum_{n=1}^N \left( y (x,\boldsymbol{w}) - t_n \right)^2$$

この損失関数を最適化アルゴリズムを用いて最小化することで私たちの目的は達成されます。ここでは最適化アルゴリズムについて詳しく説明しませんが、Adam(Adaptive Moment Estimation)を用いて損失関数の最小化を行っていきます。

# 損失関数の設定
loss = tf.div(tf.reduce_sum(tf.square(tf.subtract(y, t))), 2.)

# 最適化アルゴリズムの設定
learning_rate = 0.50
opt = tf.train.AdamOptimizer(learning_rate)
train = opt.minimize(loss)

学習と結果

ここまでくれば後は学習を行うだけです。今回は2000回学習させています。

loss_vec = []   # 損失を保存しておくためのリスト

with tf.Session() as sess:
    # 変数初期化
    sess.run(tf.global_variables_initializer())

    # 訓練開始
    epoch = 2000
    for step in range(epoch):
        sess.run(train, feed_dict={x: x_vals})
        tmp_loss = sess.run(loss, feed_dict={x: x_vals})
        loss_vec.append(tmp_loss)
        if step % (epoch/10) == 0:
            print('step[{}]  loss : {}'.format(step, tmp_loss))

    # 結果の表示
    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(15, 4))
    ax1.plot(c_vals, sess.run(sin, feed_dict={x: c_vals}), color='green', label='sin(2πx)')
    ax1.plot(c_vals, sess.run(y, feed_dict={x: c_vals}), label='{} order polynomial'.format(M))
    ax1.scatter(x_vals, sess.run(t, feed_dict={x: x_vals}), label='target')
    ax1.set_xlabel('x')
    ax1.set_ylabel('t')
    ax1.legend()
    ax2.plot(loss_vec)
    ax2.set_xlabel('Generation')
    ax2.set_ylabel('loss')
    fig.show()

実行結果は以下のようになります。

step[0]  loss : 5.928642272949219
step[200]  loss : 0.7202547192573547
step[400]  loss : 0.4076552391052246
step[600]  loss : 0.3739248514175415
step[800]  loss : 0.21431005001068115
step[1000]  loss : 0.1080390065908432
step[1200]  loss : 0.09304159134626389
step[1400]  loss : 0.6156651377677917
step[1600]  loss : 0.23187297582626343
step[1800]  loss : 0.08678792417049408

学習が進むにつれて損失が減少していき、2000回学習した時点で多項式が$\sin(2\pi x)$に近似していることが分かります。

おわりに

今回は多項式曲線フィッティングをTensorFlowで実装しました。$M$の値を変更すれば分かるのですが、次数を増やせば理想的な回帰ができるという訳ではありません。これは過学習が起こるためです。過学習をある程度防ぐための手法として正則化などありますが、それはまたいずれ。

参考

[1] C.M.ビショップ(2011) 「パターン認識と機械学習 上 : ベイズ理論による統計的予測」丸善.

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
6