1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Google Colaboratory で Chainer を触ってみるvol.4 ~活性化関数を理解する~

Last updated at Posted at 2018-12-02

※この記事では, 一般的なニューラルネットワークの特徴, CNN のしくみなどは取り上げません. それらは各自で学習済みと想定しています.

前回

Google Colaboratory で Chainer を触ってみるvol.3 ~nodeを理解する~

進捗

隠れ層の数, 活性化関数, optimizer, 学習時のパラメータなどが学習時間と学習精度にどのような影響を及ぼすか, 「00 Colaboratory で Chainer を動かしてみよう」で実験中です. 今回は, 「活性化関数の影響」を確認しました.

理解したこと表

どの影響? 学習時間 学習精度 備考
隠れ層の数 + 0.3秒/1layer 8層でほぼ頭打ち -
隠れ層の node 数 ほぼ変わらず 500 node で +1.5% GPU 使えば何でもよさそう
活性化関数 +2.04s(selu) +2.7%(selu) パラメータをいじるともう少し精度上がる
epoch 数 - - -
batchsize - - -
optimizer - - -

活性化関数とは

node(ニューロン)への重みづけ入力とバイアスの総和が活性化(発火)するかどうかを決める関数です. 活性化した値が次の node への入力となります.

Chainer で用意されている活性化関数は, こちらにまとめられています. 今回は, Qiita - Chainerの活性化関数を可視化で取り上げられている9つの関数を試しました.

試行錯誤するコード

いつものやつです. そのほかのコードは, Vo.1で投稿したものと同じです.

.py
# class_model.py
import chainer.functions as F
import chainer.links as L
from chainer import Chain
class MLPNew(Chain):
  
  def __init__(self):
      super(MLPNew, self).__init__()
      with self.init_scope():
          # Add more layers?
          self.l1 = L.Linear(784, 200) # Increase output node as (784, 300)?
          self.l2 = L.Linear(200, 200) # Increase nodes as (300, 300)?
          self.l3 = L.Linear(200, 10)  # Increase nodes as (300, 10)?
  
  def forward(self, x):
      h1 = F.tanh(self.l1(x))  # Replace F.tanh with F.sigmoid or F.relu?
      h2 = F.tanh(self.l2(h1)) # Replace F.tanh with F.sigmoid or F.relu?
      y = self.l3(h2)
      return y

# do_train_and_validate.py
device = 0
n_epoch = 5     # Add more epochs?
batchsize = 256 # Increase/Decrease mini-batch size?

model = MLPNew()
classifier_model = L.Classifier(model)
optimizer = optimizers.SGD() # Default SGD(). Use other optimizer, Adam()?(Are there Momentum and AdaGrad?)

train_and_validate(
    classifier_model, optimizer, train, validation, n_epoch, batchsize, device)

今回は, 下記部分の tanh を各活性化関数に変更すれば試せます.

.py
      h1 = F.tanh(self.l1(x))  # Replace F.tanh with F.sigmoid or F.relu?
      h2 = F.tanh(self.l2(h1)) # Replace F.tanh with F.sigmoid or F.relu?

結果詳細

デフォルト(HandsOn 通りの設定)

  • 隠れ層 : 3
  • 隠れ層の node 数 : 200
  • 活性化関数 : tanh ←今回はこれを変更する
  • epoch 数 : 5
  • batchsize : 256
  • optimizer : SGD
活性化関数 合計学習時間 学習精度
tanh(デフォルト) 16.0598(基準) 0.78984374(基準)
clipped_relu 16.2518(+0.2s) 0.7798828(+1%)
hard_sigmoid 15.7152(-0.34s) 0.46464843(-32.5%)
sigmoid 15.8886(-0.17s) 0.47636718(-31.3)
leaky_relu 15.4042(-0.65s) 0.78164065(-0.8%)
relu 16.0199(-0.04s) 0.7799805(-1%)
elu 16.5385(+0.48s) 0.7970703(+0.8%)
selu 18.0986(+2.04s) 0.8165039(+2.7%)
softplus 15.8983(-0.16s) 0.74472654(-4.5%)
selu 関数が一番良好な結果と思われます. sigmoid 関数はダメそうですね.
あと, 表に書き忘れていますが, 学習結果とテスト結果に大きな差はないので, 過学習は発生してません.

selu のパラメータをいじってみる

selu に関する公式の説明には, α や scale について設定があるので, こちらをいじって結果を比較しました.

.py
  def forward(self, x):
      h1 = F.selu(self.l1(x),alpha=1.6732632423543772, scale=1.0507009873554805)
      h2 = F.selu(self.l2(h1),alpha=1.6732632423543772, scale=1.0507009873554805)
      y = self.l3(h2)
      return y
α scale 合計学習時間 学習精度
未設定 未設定 18.0986 0.8165039
1.6732632423543772 1.0507009873554805 17.8415 0.8162109
2.4 1.0507009873554805 17.9607 0.83095706
1.6732632423543772 1.5 17.8783 0.8292969
2.4 1.5 17.695 0.8381836
α, scale どちらも大きい方が精度は良くなりそうです. これらパラメータは学習時間に影響はほぼなさそうです. ほんとは関数の定義をきちんと理解すべきですが, 現時点はここまでにとどめます. selu 以外の関数にもパラメータはありますので, チューニングすれば精度が逆転するかもしれませんね.

まとめ

活性化関数を変更すると, 学習時間と精度に差が出ることがわかりました. ただ, sigmoid 関数以外であれば顕著な差はなさそうです. 各関数にはそれぞれパラメータがありますので, その値をチューニングすればもっと精度が上がるかもしれません.

次回

DeepLearning のエポック数とバッチサイズを理解する.

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?