2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Google Colaboratory で Chainer を触ってみるvol.7 ~LeNet を嗜む~

Last updated at Posted at 2018-12-12

前回

Google Colaboratory で Chainer を触ってみるvol.6 ~optimizerの理解、いままでのまとめ~

進捗

00 Colaboratory で Chainer を動かしてみよう」の最後課題に挙げられていた LeNet-5 を試しました.

LeNet-5 とは

LeNet5 とは, 5層の LeNet です. 3つの畳み込み(convolution)層と, 2つの全結合層を持つ, 計5層のネットワークです.
image.png
(HandsOn のページから転載しました.)

LeNet は手書き数字認識を行うネットワークとして 1998 年に提案された, CNN の元祖です. 畳み込み層, プーリング層(要素を間引く層)を連続させ, 最後に全結合層を経て結果が出力されます. 昔は LeNet の活性化関数に sigmoid が使われていましたが, 最近では relu が用いられるそうです.

LeNet-5 で Fashion MNIST をやってみる

サンプルコードを動かしつつ, 精度を 90% 以上にできないか試してみました. 使用したコードは, この記事の最後に記載しています.

  • 精度 90% 以上
  • 学習時間 200sec 以内

結果

デフォルト設定(HandsOn 通りの設定)

  • 活性化関数 : sigmoid
  • epoch 数 : 10
  • batchsize : 256
  • optimizer : Adam
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time
1           1.83379     0.349968       1.16711        0.570508           3.91257       
2           0.98471     0.628926       0.89656        0.668945           7.99946       
3           0.810149    0.703085       0.777537       0.70498            11.9553       
4           0.711494    0.733          0.698332       0.732031           16.0126       
5           0.647929    0.74988        0.652298       0.74541            20.0139       
6           0.608307    0.76232        0.612736       0.754687           24.1201       
7           0.575744    0.772939       0.583578       0.767773           28.1152       
8           0.551792    0.78135        0.559339       0.773145           32.1578       
9           0.528406    0.790966       0.542836       0.788965           36.1562       
10          0.506318    0.800163       0.51899        0.789941           40.2234       

Test accuracy: 0.7987305

80% 行くか行かないか.

変更1 : batchsize

  • 活性化関数 : sigmoid
  • epoch 数 : 10
  • batchsize : 128 ←
  • optimizer : Adam
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time
1           1.63544     0.404412       0.946665       0.662579           7.17152       
2           0.779716    0.711297       0.703671       0.731013           14.6435       
3           0.640226    0.753405       0.620902       0.756527           22.077        
4           0.573759    0.775196       0.578544       0.768987           29.395        
5           0.530182    0.79144        0.529941       0.790645           36.7229       
6           0.497212    0.80621        0.50374        0.803699           44.0955       
7           0.47194     0.817915       0.479635       0.815071           51.6966       
8           0.450012    0.829187       0.458499       0.828422           59.1986       
9           0.43215     0.839494       0.447579       0.832872           66.6459       
10          0.416274    0.845908       0.430306       0.839695           74.0034       

Test accuracy: 0.84169924

batchsize を減らすといい感じで精度が上がりました.

変更2 : 活性化関数

  • 活性化関数 : relu ←
  • epoch 数 : 10
  • batchsize : 256
  • optimizer : Adam
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time
1           0.835276    0.696309       0.615126       0.760254           4.03892       
2           0.534279    0.79988        0.525212       0.798535           8.42688       
3           0.466187    0.83143        0.484503       0.814062           12.7898       
4           0.430691    0.844348       0.436197       0.837988           17.2373       
5           0.397923    0.857312       0.415125       0.844238           21.638        
6           0.379096    0.863762       0.383555       0.861328           26.017        
7           0.358667    0.871851       0.376659       0.859961           30.3771       
8           0.347551    0.8748         0.367091       0.866992           34.8929       
9           0.330936    0.880729       0.36706        0.864648           39.1901       
10          0.321007    0.883311       0.346736       0.873633           43.4323       

Test accuracy: 0.8751953

活性化関数を変えるだけで, だいぶ精度が上がりました.

変更3 : 活性化関数と batchsize

  • 活性化関数 : relu ←
  • epoch 数 : 10
  • batchsize : 128 ←
  • optimizer : Adam
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time
1           0.718779    0.736833       0.54409        0.797765           7.3917        
2           0.469592    0.829384       0.446577       0.839794           15.0238       
3           0.406688    0.853546       0.393267       0.857595           22.5937       
4           0.373509    0.866188       0.380327       0.862045           30.3102       
5           0.348192    0.874321       0.363642       0.86521            37.9552       
6           0.330103    0.879647       0.366197       0.865605           45.6231       
7           0.315076    0.88523        0.341243       0.87144            53.2682       
8           0.301359    0.890264       0.33667        0.876879           61.0204       
9           0.288146    0.89538        0.32718        0.878659           68.7567       
10          0.279945    0.897498       0.314154       0.88301            76.5967       

Test accuracy: 0.88427734

精度への影響は, やはり活性化関数が支配的ですね. さらにエポック数やバッチサイズを変えても, Test accuracy は 88% 以降ほぼ横ばいでした. 90% 以上の道は近そうで遠い...このあたりの最後のチューニングがノウハウなんでしょうかね.

まとめ

LeNet は元祖 CNN. 精度を上げる最後のチューニングには, 何らかのノウハウが必要っぽい.

次回

01 Chainerの基本的な使い方を学んでみようをやる.

参考:試したコード

LeNet5.py
# colab の cuda に応じて、いい感じに chainer と Cupy をインストールするコマンド
!curl https://colab.chainer.org/install | sh -

# chainer のインストール確認
!python -c "import chainer; chainer.print_runtime_info()"

# graphviz のインストール
!apt -y -qq install graphviz > /dev/null 2> /dev/null
!pip install pydot

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import chainer

# ラベル ID (label)からラベル名を取得する関数
LABEL_NAMES = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot"
]

def get_label_name(label):
  return LABEL_NAMES[label]

# fashion MNIST のデータを取ってきて学習用, validation 用, test 用に分ける
from chainer.datasets.fashion_mnist import get_fashion_mnist

train, test = get_fashion_mnist(withlabel=True, ndim=1)
train, validation = chainer.datasets.split_dataset_random(train, 50000, seed=0)

# 学習
from chainer import optimizers, training
from chainer.training import extensions

def train_and_validate(
      model, optimizer, train, validation, n_epoch, batchsize, device):
  # 1.device が gpu なら、gpu にモデルデータを転送する
  if device >= 0:
    model.to_gpu(device)
  
  # 2. optimizer を設定する
  optimizer.setup(model)
  
  # 3. Dataset から Iterator を作成する
  train_iter = chainer.iterators.SerialIterator(train, batchsize)
  validation_iter = chainer.iterators.SerialIterator(
      validation, batchsize, repeat=False, shuffle=False)
  
  # 4. Updater, Trainer を作成する
  updater = training.StandardUpdater(train_iter, optimizer, device=device)
  trainer = chainer.training.Trainer(updater, (n_epoch, "epoch"), out="out")
  
  # 5. Trainer の機能を拡張する
  trainer.extend(extensions.LogReport())
  trainer.extend(extensions.Evaluator(validation_iter, model, device=device), name="val")
  trainer.extend(extensions.PrintReport(
      ["epoch", "main/loss", "main/accuracy", "val/main/loss", "val/main/accuracy", "elapsed_time"]))
  trainer.extend(extensions.PlotReport(
      ["main/loss", "val/main/loss"], x_key="epoch", file_name="loss.png"))
  trainer.extend(extensions.PlotReport(
      ["main/accuracy", "val/main/accuracy"], x_key="epoch", file_name="accuracy.png"))
  trainer.extend(extensions.dump_graph("main/loss"))
  
  # 6. 訓練を開始する
  trainer.run()

# グラフ表示
import pydot
from IPython.display import Image, display

def show_graph():
    graph = pydot.graph_from_dot_file('out/cg.dot') # load from .dot file
    graph[0].write_png('graph.png')
    display(Image('graph.png', width=600, height=600))

def show_loss_and_accuracy():
    display(Image(filename="out/loss.png"))
    display(Image(filename="out/accuracy.png"))

# test 結果計算
def show_test_performance(model, test, device, batchsize=256):
    if device >= 0:
        model.to_gpu()
    test_iter = chainer.iterators.SerialIterator(
        test, batchsize, repeat=False, shuffle=False
    )
    test_evaluator = extensions.Evaluator(test_iter, model, device=device)
    results = test_evaluator()
    print("Test accuracy:", results["main/accuracy"])

# LeNet-5
class LeNet5(Chain):
  def __init__(self):
    super(LeNet5, self).__init__()
    with self.init_scope():
      self.conv1 = L.Convolution2D(
          in_channels=1, out_channels=6, ksize=5, stride=1, pad=0)
      self.conv2 = L.Convolution2D(
          in_channels=6, out_channels=16, ksize=5, stride=1, pad=0)
      self.conv3 = L.Convolution2D(
          in_channels=16, out_channels=120, ksize=4, stride=1, pad=0)
      self.fc4 = L.Linear(None, 84)
      self.fc5 = L.Linear(84, 10)
 
  def forward(self, x):
#     h = F.sigmoid(self.conv1(x.reshape((-1, 1, 28, 28))))
#     h = F.max_pooling_2d(h, ksize=2, stride=2)
#     h = F.sigmoid(self.conv2(h))
#     h = F.max_pooling_2d(h, ksize=2, stride=2)
#     h = F.sigmoid(self.conv3(h))
#     h = F.sigmoid(self.fc4(h))

    h = F.relu(self.conv1(x.reshape((-1, 1, 28, 28))))
    h = F.max_pooling_2d(h, ksize=2, stride=2)
    h = F.relu(self.conv2(h))
    h = F.max_pooling_2d(h, ksize=2, stride=2)
    h = F.relu(self.conv3(h))
    h = F.relu(self.fc4(h))

    return self.fc5(h)

# 学習と test 実行
device = 0
n_epoch = 10
batchsize = 256

model = LeNet5()
classifier_model = L.Classifier(model)
optimizer = optimizers.Adam()
train_and_validate(
    classifier_model, optimizer, train, validation, n_epoch, batchsize, device)
show_test_performance(classifier_model, test, device)
show_loss_and_accuracy()
show_graph()
2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?