
Trying Regression with TensorFlow

Posted at 2015-11-15

Now that TensorFlow is out, I tried having it learn something other than MNIST.
We run a regression and look at the correlation with the targets.
(Addendum) The hidden layer had too many outputs relative to the amount of data, so I revised the hidden-layer size.
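
For reference, the "test corrcoef" reported in the results below is the Pearson correlation coefficient between the model's predictions and the test targets, which is what np.corrcoef computes:

```math
r = \frac{\sum_i (\hat{y}_i - \bar{\hat{y}})(y_i - \bar{y})}{\sqrt{\sum_i (\hat{y}_i - \bar{\hat{y}})^2}\,\sqrt{\sum_i (y_i - \bar{y})^2}}
```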

##Overview

  • The dataset used is the diabetes dataset from scikit-learn (a quick inspection sketch follows this list).
  • The weights and biases are set fairly arbitrarily.
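
As a quick check (a minimal sketch, not part of the script below): the diabetes dataset shipped with scikit-learn has 442 samples with 10 features each, so splitting at N=342 leaves 100 samples for the test set.

```python
from sklearn import datasets

# Inspect the diabetes dataset used below
diabetes = datasets.load_diabetes()
print diabetes['data'].shape    # (442, 10)
print diabetes['target'].shape  # (442,)
```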

##Source code

import sklearn
import tensorflow as tf
from sklearn import datasets
import numpy as np

diabetes = datasets.load_diabetes()

#Load the data
print "load diabetes data"
data = diabetes["data"].astype(np.float32)
target = diabetes['target'].astype(np.float32).reshape(len(diabetes['target']), 1)
#Split into training and test data
N=342
x_train, x_test = np.vsplit(data, [N])
y_train, y_test = np.vsplit(target, [N])
N_test = y_test.size

x = tf.placeholder("float", shape=[None, 10])

# Layer 1: input 10, output 256
with tf.name_scope('l1') as scope:
    weightl1 = tf.Variable(tf.truncated_normal([10, 256], stddev=0.1), name="weightl1")
    biasel1 = tf.Variable(tf.constant(1.0, shape=[256]), name="biasel1")
    outputl1 = tf.nn.relu(tf.matmul(x, weightl1) + biasel1)

# Layer 2: input 256, output 1
with tf.name_scope('l2') as scope:
    weightl2 = tf.Variable(tf.truncated_normal([256, 1], stddev=0.1), name="weightl2")
    biasel2 = tf.Variable(tf.constant(1.0, shape=[1]), name="biasel2")
    outputl2 = tf.nn.relu(tf.matmul(outputl1, weightl2) + biasel2)


    
"""
誤差計算のための関数
MSEで誤差を算出
"""
def loss(output):
    with tf.name_scope('loss') as scope:
        loss = tf.reduce_mean(tf.square(output - y_train))
    return loss


loss_op = loss(outputl2)
optimizer = tf.train.AdagradOptimizer(0.04)
train_step = optimizer.minimize(loss_op)

#Keep track of the best (lowest) training loss
best = float("inf")

# Initialization op
init_op = tf.initialize_all_variables()

with tf.Session() as sess:
    # Run the initialization
    sess.run(init_op)
    for i in range(20001):
        loss_train = sess.run(loss_op, feed_dict={x:x_train})
        sess.run(train_step, feed_dict={x:x_train})
        if loss_train < best:
            best = loss_train
            best_match = sess.run(outputl2, feed_dict={x:x_test})
        if i % 1000 == 0:
            print "step {}".format(i)
            pearson = np.corrcoef(best_match.flatten(), y_test.flatten())
            print 'train loss = {} ,test corrcoef={}'.format(best,pearson[0][1])
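
As a side note (not part of the original script), the test-set MSE can also be reported from best_match and y_test once training finishes; a minimal sketch:

```python
# Hypothetical addition: report the test-set MSE alongside the correlation
test_mse = np.mean((best_match.flatten() - y_test.flatten()) ** 2)
print 'test mse = {}'.format(test_mse)
```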
               

##Results

load diabetes data
step 0
train loss = 29000.1777344 ,test corrcoef=0.169487254139
step 1000
train loss = 3080.2097168 ,test corrcoef=0.717823972634
step 2000
train loss = 2969.1887207 ,test corrcoef=0.72972180486
step 3000
train loss = 2938.4609375 ,test corrcoef=0.73486349373
step 4000
train loss = 2915.63330078 ,test corrcoef=0.737497869454
step 5000
train loss = 2896.14331055 ,test corrcoef=0.739029181368
step 6000
train loss = 2875.51708984 ,test corrcoef=0.74006814362
step 7000
train loss = 2856.36816406 ,test corrcoef=0.741115477047
step 8000
train loss = 2838.77026367 ,test corrcoef=0.742113966068
step 9000
train loss = 2822.453125 ,test corrcoef=0.743066699589
step 10000
train loss = 2807.88916016 ,test corrcoef=0.743988699821
step 11000
train loss = 2795.09057617 ,test corrcoef=0.744917437376
step 12000
train loss = 2783.8828125 ,test corrcoef=0.745871358086
step 13000
train loss = 2773.68457031 ,test corrcoef=0.747112534114
step 14000
train loss = 2764.80224609 ,test corrcoef=0.748115829411
step 15000
train loss = 2756.6628418 ,test corrcoef=0.748800330555
step 16000
train loss = 2749.1340332 ,test corrcoef=0.749471871992
step 17000
train loss = 2741.78881836 ,test corrcoef=0.750184567587
step 18000
train loss = 2734.56054688 ,test corrcoef=0.750722087518
step 19000
train loss = 2727.18579102 ,test corrcoef=0.751146409281
step 20000
train loss = 2719.29101562 ,test corrcoef=0.751330770654

##Issues

  • For some reason, corrcoef sometimes becomes nan (under investigation).
    • outputl2 ends up being all 0.
      -> An exploding-gradient problem may be occurring (one possible mitigation is sketched below).
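
As one possible mitigation (an assumption on my part, not verified here): since the output layer also applies ReLU, outputl2 collapses to all zeros whenever its pre-activation goes negative, and np.corrcoef then returns nan because the prediction has zero variance. For regression, using a linear output layer avoids that particular failure mode. A minimal sketch of the changed layer:

```python
# Hypothetical variant of the second layer: drop the final ReLU so the
# regression output can take any real value (avoids the all-zero collapse).
with tf.name_scope('l2') as scope:
    weightl2 = tf.Variable(tf.truncated_normal([256, 1], stddev=0.1), name="weightl2")
    biasel2 = tf.Variable(tf.constant(1.0, shape=[1]), name="biasel2")
    outputl2 = tf.matmul(outputl1, weightl2) + biasel2  # no tf.nn.relu here
```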

It seems you can configure things quite freely.
It looks like it will still take me some time to get the hang of it.
