Help us understand the problem. What is going on with this article?

Chainerの新機能がすごい勢いで俺のソースコードを抽象化してくる

More than 3 years have passed since last update.

2016年7月12日に,Chainer v1.11.0がリリースされた.
Chainer Meetup 03や,Amazon Picking Challenge 2016の準備もあったろうに…なんて速度だ!
(第2位,おめでとうございます https://www.preferred-networks.jp/ja/news/amazon-picking-challenge-2016_result

この記事では,Chainerの新機能を美味しさと共に紹介していきたいと思う.

トレーニング部分のコードが抽象化できるようになった

Chainer Meetup 03では,「beam2dさんがゲロ吐きながらTrainコードを抽象化してくれている」という話だった.
あれからおよそ2週間,Chainerのコードは変化を遂げた.変化を遂げることを許された.

先日とあるライブコーディングイベントにて私が書いた,MLPでMNISTの分類を行うコードである.

from sklearn.datasets import fetch_mldata
from sklearn.cross_validation import train_test_split
import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import optimizers

mnist = fetch_mldata('MNIST original', data_home='.')
mnist.data = mnist.data.astype(np.float32) * 1.0 / 255.0
mnist.target = mnist.target.astype(np.int32)

train_data, test_data, train_label, test_label = train_test_split(mnist.data, mnist.target, test_size=10000,
                                                                  random_state=222)
print "data shape ", mnist.data.dtype, mnist.data.shape
print "label shape ", mnist.target.dtype, mnist.target.shape

class MnistModel(chainer.Chain):
    def __init__(self):
        super(MnistModel, self).__init__(
                l1=L.Linear(784, 100),
                l2=L.Linear(100, 100),
                l3=L.Linear(100, 10)
        )

    def __call__(self, x, t, train):

        x = chainer.Variable(x)
        t = chainer.Variable(t)

        h = F.relu(self.l1(x))
        h = F.relu(self.l2(h))
        h = self.l3(h)

        if train:
            return F.softmax_cross_entropy(h, t), F.accuracy(h, t)
        else:
            return F.accuracy(h, t)


model = MnistModel()
optimizer = optimizers.Adam()
optimizer.setup(model)

for epoch in range(100):
    model.zerograds()
    loss, acc = model(train_data, train_label, train=True)
    loss.backward()
    optimizer.update()
    print "acc  ", acc.data

acc = model(test_data, test_label, train=False)
print "acc test ", acc.data

綺麗かと言われると一悶着起きそうなコードだが,解説しながら5分ほどで書ける行数に収まっている.
50行ちょっとである.

そしてこれがChainer v1.11.0の新機能を用いて書いたMNISTの分類を行うコードだ.

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

class MnistModel(chainer.Chain):
    def __init__(self):
        super(MnistModel,self).__init__(
                l1 = L.Linear(784,100),
                l2 = L.Linear(100,100),
                l3 = L.Linear(100,10))

    def __call__(self,x):    
         h = F.relu(self.l1(x))
         h = F.relu(self.l2(h))
         return self.l3(h)


model = L.Classifier(MnistModel())
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

train, test = chainer.datasets.get_mnist()
train_iter = chainer.iterators.SerialIterator(train, 100)
test_iter = chainer.iterators.SerialIterator(test, 100,repeat=False, shuffle=False)

updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (100, 'epoch'), out="result")
trainer.extend(extensions.Evaluator(test_iter, model, device=-1))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport( ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())

trainer.run()

なんと30行ちょっと!
チュートリアル用のモジュールを用意しているわけじゃないぞ!
(参考 http://www.slideshare.net/chabudaigaeshi/tensorflowchainer-63661517
↑repeatedlyさん,編集リクエストありがとうございます!
(他のフレームワークの悪口を言いたいわけではないです)
短さのみならず,その抽象度の高さにも目玉が飛び出る.

しかも,ちょっとリッチな感じで学習経過を教えてくれる.
これで「Chainerって学習結果の出力とか自分でやらなきゃダメなんでしょ?」という意地悪な質問にドヤ顔で返せるぞ!

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
1           0.333247    0.163566              0.90525        0.951
2           0.13979     0.120721              0.95835        0.9645
3           0.0972047   0.10129               0.970817       0.9682
4           0.0724406   0.0958347             0.9781         0.9712
5           0.05642     0.0935157             0.983267       0.9707
6           0.0443315   0.0999502             0.987183       0.9684
     total [##############################....................] 60.00%
this epoch [..................................................]  0.00%
      3600 iter, 6 epoch / 10 epochs

これだけ抽象度が高いと美少女Chainerユーザーの「チューしよう度」も異常な高まりを見せちゃいそうだ.
もちろん,従来の書き方も可能である

なんだかChainerを使えば,人生うまくいっちゃうような気さえしてきた.
ありがとうChainer

Variableでのラッピングが自動化された

( 2016年7月16日追記 )

Chainerの関数たちはこれまで,Variableを入力として受け付けていた.
最新版では,入力がndarrayであった場合自動でラッピングしてくれるようになった.
上記のMNIST(Trainer使わない版)の__call__部を,以下のように変更することが可能である.
ここで,xとtはxp.arrayを想定している.

    def __call__(self, x, t, train):

        #x = chainer.Variable(x)
        #t = chainer.Variable(t)

        h = F.relu(self.l1(x))
        h = F.relu(self.l2(h))
        h = self.l3(h)

        if train:
            return F.softmax_cross_entropy(h, t), F.accuracy(h, t)
        else:
            return F.accuracy(h, t)

シームレスなコードになった.
これは,変数の型について注意を払う必要が減り,エンジニアやネコの脳にかかる負担が軽減されたってことだ.

この変更の真の価値はインタラクティブシェルでの記述量が減る点にあると考えている.
旧バージョンではお試しデータを用意する際にもchainer.Variable( np.zeros( (3,224,224), dtype=float32 ))とカッコの対応に注意しながら入力するか,
インタラクティブシェルからお叱りを受けるかという状態であった.
これからはインタラクティブシェルからお叱りを受けることは減るだろう.Chainerと少しだけ仲良くなれた感覚がある.

/* Chainer v1.11.0+ の新機能について,分かり次第追記していきます.よろしくお願いします */


Chainer Meetup 03で発表されたアップデート内容(未確認)
- CaffeFunctionがPython3でも呼べるようになった.ResNetのCaffeモデルをサポートした.
- 重み初期化が簡単に(Initializerの追加)
- Float64が利用可能に
- CupyのプロファイルAPIが追加
- Variable.__getitem__の実装(スライスインデクシングなどに対応)
- roi_pooling_2dをはじめとする,様々なFunctions/Linksの追加

_329_
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした