LoginSignup
44
38

More than 5 years have passed since last update.

chainerを使って回帰してみると、ちょっとはまる

Last updated at Posted at 2015-06-20

概要

  • はじめまして、qiitaに初投稿です。至らぬところが多いと思います。
  • ニューラルネットワークのフレームワークのサンプルコードは大体分類問題で、回帰の例があまりないのでトライしてみました。
  • はまる箇所がありました。上手な書き方を知りたいので、投稿します。

  • theanoだと同等の事をするのに1000行以上書きましたが、chainerだと98行でした。chainerに移行します。

  • GPUでも問題なく動きました。

  • 結論: chainerすごい

ソースコード

はまった箇所

  • 回帰問題の場合、正解データのshapeが(データ数,)だと,batchsizeが2以上の時になぜかnumpyのbroadcastingがうまくいきません。
  • reshapeで(データ数,1)にするとbroadcastingできるようになりました。
  • theanoでも全く同じ現象が起こるのでnumpyの問題だと思います。
#targetは正解データ
#mnist.pyだと↓のような書き方になっている 
target = diabetes['target'].astype(np.float32) #これだとミニバッチの数が2以上だと動かない

↑だと↓のようなValueErrorがでます(13というのはbatchsize)。

ValueError: non-broadcastable output operand with shape (1,30) doesn't match the broadcast shape (13,30)

試行錯誤した結果これで動きました。

#reshapeしないといけない
target = diabetes['target'].astype(np.float32).reshape(len(diabetes['target']), 1)
  • 醜いかつ、相関係数が計算しにくいなど、不便な点があります。
  • もう少しうまく書けないでしょうか?

example/mnist.pyからの変更点

環境

chainerのインストールに関して

  • 最初はpip install chainerでインストールした(6月19日)chainerを使ってましたが、GPU上でadadeltaが動きませんでした(一文抜けてた)。
  • github上では修正されていたので、git cloneして、python setup.py install
  • pip install chainer-cuda-deps

変更点

データセット

  • scikit-learnのdiabetesデータセットを利用しました。
  • 入力データ: 10次元で442サンプル(小さい上に少なくてごめんなさい)

モデル

  • ネットワークのアーキテクチャ
    • 10(input)-30-30-1(output)
n_units   = 30
model = FunctionSet(l1=F.Linear(10, n_units),                                 
                    l2=F.Linear(n_units, n_units),
                    l3=F.Linear(n_units, 1)) 

学習率の調整

  • Adadelta(Adamよりも良かった)
optimizer = optimizers.AdaDelta(rho=0.9)

誤差関数と評価手法

  • 誤差関数は(ミニバッチ内の)平均二乗誤差
# Neural net architecture
def forward(x_data, y_data, train=True):
    x, t = Variable(x_data), Variable(y_data)
    h1 = F.dropout(F.relu(model.l1(x)),  train=train)
    h2 = F.dropout(F.relu(model.l2(h1)), train=train)
    y  = model.l3(h2)
    # 平均二乗誤差と予測結果を返す
    # 予測結果を返すのは、後で予測結果と正解データで相関係数を計算するから
    return F.mean_squared_error(y, t), y
  • 正解率は出せないので、分類と同じ評価はできません。
  • 誤差だけだとわかりにくいので、普通はR2値を利用する?
  • 個人的事情で、予測値と正解データの相関係数で評価しました。

  • はまった箇所(上参照)のせいで相関係数を計算する行が超絶汚くなります。

pearson = np.corrcoef(np.asarray(preds).reshape(len(preds),), np.asarray(y_test).reshape(len(preds),))

感想

  • theanoに比べてデバッグが本当に楽でした(theanoはどこでエラーが起きているのか本当わかりにくい)。
  • 感動するレベルでコードが短くなりました。
  • weight decayはできるみたいですけど、lassoは実現できるのでしょうか?
44
38
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
44
38