LoginSignup
2
5

More than 5 years have passed since last update.

Google Colaboratory で Chainer を触ってみるvol.1 ~HandsOn 通りに動かす~

Last updated at Posted at 2018-11-27

はじめに

無料でGPU環境が利用できる Google Colaboratory で Chainer を使いながら, Deeplearning を勉強しています.
Google Colaboratory を触ってみる」が親ページです.

目標

Chainer Colab Notebook の HandsOn を一通りやる

進捗

00 Colaboratory で Chainer を動かしてみよう」を実践中. 一通り学習, 検証, テストをするためのコードを書きました. 内容自体は, ほぼコピペです.

動かしたコードの羅列

install.py
# google colab で chainer などを使うための設定

# colab の cuda に応じて、いい感じに chainer と Cupy をインストールするコマンド
!curl https://colab.chainer.org/install | sh -

# chainer のインストール確認コマンド
!python -c "import chainer; chainer.print_runtime_info()"

# graphviz のインストール
!apt -y -qq install graphviz > /dev/null 2> /dev/null
!pip install pydot

# 必要なライブラリをまとめてインポートしておく
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import chainer

chainer は Colab にデフォルトで入っていないので, インストールします. ここでは graphviz も使うので, 一緒にインストールします. 先頭に "!" を付けると, コマンドラインの命令を実行することができます.

一定時間何もしていなかったり, PC をシャットダウンして google colab サーバとの接続が切れると, 一からインストールし直す必要があるようです.


get_label.py
# ラベル ID (label)からラベル名を取得する関数
LABEL_NAMES = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot"
]

def get_label_name(label):
  return LABEL_NAMES[label]

get_fashion_mnist.py
# 学習, test に使うデータを取得.
from chainer.datasets.fashion_mnist import get_fashion_mnist
train, test = get_fashion_mnist(withlabel=True, ndim=1)

# 学習データを, 学習用と検証用に分ける
train, validation = chainer.datasets.split_dataset_random(train, 50000, seed=0)

データは, 学習用(受験勉強), 検証用(模試), テスト用(受験本番)の三つの用途に分割します.


func_train_and_validation.py
from chainer import optimizers, training
from chainer.training import extensions

def train_and_validate(
      model, optimizer, train, validation, n_epoch, batchsize, device):
  # 1.device が gpu なら、gpu にモデルデータを転送する
  if device >= 0:
    model.to_gpu(device)

  # 2. optimizer を設定する
  optimizer.setup(model)

  # 3. Dataset から Iterator を作成する
  train_iter = chainer.iterators.SerialIterator(train, batchsize)
  validation_iter = chainer.iterators.SerialIterator(
      validation, batchsize, repeat=False, shuffle=False)

  # 4. Updater, Trainer を作成する
  updater = training.StandardUpdater(train_iter, optimizer, device=device)
  trainer = chainer.training.Trainer(updater, (n_epoch, "epoch"), out="out")

  # 5. Trainer の機能を拡張する
  trainer.extend(extensions.LogReport())
  trainer.extend(extensions.Evaluator(validation_iter, model, device=device), name="val")
  trainer.extend(extensions.PrintReport(
      ["epoch", "main/loss", "main/accuracy", "val/main/loss", "val/main/accuracy", "elapsed_time"]))
  trainer.extend(extensions.PlotReport(
      ["main/loss", "val/main/loss"], x_key="epoch", file_name="loss.png"))
  trainer.extend(extensions.PlotReport(
      ["main/accuracy", "val/main/accuracy"], x_key="epoch", file_name="accuracy.png"))
  trainer.extend(extensions.dump_graph("main/loss"))

  # 6. 訓練を開始する
  trainer.run()

学習とパラメータの検証を行う関数です. ここの中身は追って要勉強.


func_show_graph.py
import pydot
from IPython.display import Image, display

def show_graph():
    graph = pydot.graph_from_dot_file('out/cg.dot') # load from .dot file
    graph[0].write_png('graph.png')
    display(Image('graph.png', width=600, height=600))

ニューラルネットワークの, 入力から出力までの接続図っぽいものを出力する関数です.


func_show_loss_and_accuracy.py
def show_loss_and_accuracy():
    display(Image(filename="out/loss.png"))
    display(Image(filename="out/accuracy.png"))

損失関数の値と学習精度の推移がわかるグラフを表示する関数です.


func_show_examples.py
from chainer import Variable

def show_examples(model, test, device):
    plt.figure(figsize=(12,50))
    if device >= 0:
        model.to_cpu()
    for i in range(45, 105):
        data, label = test[i]
        x = Variable(np.asarray([data]))
        t = Variable(np.asarray([label]))
        y = model(x)
        prediction = y.data.argmax(axis=1)
        example = (data * 255).astype(np.int32).reshape(28, 28)
        plt.subplot(20, 5, i - 44)
        plt.imshow(example, cmap="gray")
        plt.title("No.{0}\nAnswer:{1}\nPredict:{2}".format(
            i,
            get_label_name(label),
            get_label_name(prediction[0])
        ))
        plt.axis("off")
    plt.tight_layout()

学習後に, 実際の画像+正解ラベルと, 学習結果(推論したラベル)を同時に表示します.


func_performance.py
def show_test_performance(model, test, device, batchsize=256):
    if device >= 0:
        model.to_gpu()
    test_iter = chainer.iterators.SerialIterator(
        test, batchsize, repeat=False, shuffle=False
    )
    test_evaluator = extensions.Evaluator(test_iter, model, device=device)
    results = test_evaluator()
    print("Test accuracy:", results["main/accuracy"])

テストデータでの, 最終的な精度を出力する関数です.


class_model.py
# 自分のモデルを作ってみよう
# エポック数10以下、かつ訓練時間100秒以内という制限の中で、テスト用データにおいて88%を超える精度を目指す
import chainer.functions as F
import chainer.links as L
from chainer import Chain
class MLPNew(Chain):

  def __init__(self):
      super(MLPNew, self).__init__()
      with self.init_scope():
          # Add more layers?
          self.l1 = L.Linear(784, 200) # Increase output node as (784, 300)?
          self.l2 = L.Linear(200, 200) # Increase nodes as (300, 300)?
          self.l3 = L.Linear(200, 10)  # Increase nodes as (300, 10)?

  def forward(self, x):
      h1 = F.tanh(self.l1(x))  # Replace F.tanh with F.sigmoid or F.relu?
      h2 = F.tanh(self.l2(h1)) # Replace F.tanh with F.sigmoid or F.relu?
      y = self.l3(h2)
      return y

学習モデルのクラスです. 総数やノード数, 活性化関数などを変更することができます.


do_train_and_validate.py
device = 0
n_epoch = 5     # Add more epochs?
batchsize = 256 # Increase/Decrease mini-batch size?

model = MLPNew()
classifier_model = L.Classifier(model)
optimizer = optimizers.SGD() # Default SGD(). Use other optimizer, Adam()?(Are there Momentum and AdaGrad?)

train_and_validate(
    classifier_model, optimizer, train, validation, n_epoch, batchsize, device)

学習に使用するデバイス(GPU か CPU か), エポック数, バッチサイズ, optimizer(パラメータの更新方法)を設定し, 実際に学習します.

次回

学習方法のオプション, 各パラメータ, 活性化関数を変更して, 学習への効果を確認します.

2
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
5