LoginSignup
4
4

More than 5 years have passed since last update.

ChainerのExtentionでTensorBoard出力するサンプル

Posted at

tensorboard-chainerを使うのですが、exampleにはTrainerを使った例がなかったので動作を確認するための小さいプログラムを書きました。

from tb_chainer import utils, SummaryWriter
from chainer.training import extensions
from os.path import join

class TensorBoardReport(chainer.training.Extension):
    def __init__(self, out_dir):
        self.writer = SummaryWriter(join(out_dir, datetime.now().strftime('%B%d-%H-%M-%S')))

    def __call__(self, trainer):
        observations = trainer.observation
        n_iter = trainer.updater.iteration
        for n, v in observations.items():
            if isinstance(v, chainer.Variable):
                value = v.data
            elif isinstance(v, chainer.cuda.cupy.ndarray):
                value = chainer.cuda.to_cpu(v)
            else:
                # てきとう
                value = v

            self.writer.add_scalar(n, value, n_iter)

        # とりあえずoptimizerはmainだけ
        link = trainer.updater.get_optimizer('main').target
        for name, param in link.namedparams():
            self.writer.add_histogram(name, chainer.cuda.to_cpu(param.data), n_iter)

とりあえずlossやaccuracyのグラフと、重みパラメータの勾配とヒストグラムだけ書いてみました。
validationのlossやaccuracyを書くためにEvaluatorが動くのと同じサイクルで動かすというごまかしを今回はしました。
tensorboardのimageセクションを使うなどの発展を考えると、このExtensionの中でEvaluationする必要があるか。(保留)

試しに↓のコードで学習を実行してみてログ出力してみました。
ほとんどchainerのcifar10のサンプルのままですが、環境にVGGがダウンロードされてなかったのでresnet50でやりました。

本筋とは関係ないですが、jupyterで実行しているためargsをリストで渡してパースしています。
重みパラメータが学習前後で大きく変わってる方が楽しいなと思ったので、finetuningなしでも試してみました。

from __future__ import print_function
import argparse
from os.path import join

import chainer
import chainer.links as L
from chainer import training
from chainer.training import extensions

from chainer.datasets import get_cifar10
from chainer.datasets import get_cifar100

from chainer.links import ResNet50Layers

class ResNet(chainer.links.ResNet50Layers):
    def __call__(self, x):
        return super().__call__(x)['prob']

def main(args):
    parser = argparse.ArgumentParser(description='Chainer CIFAR example:')
    parser.add_argument('--dataset', '-d', default='cifar10',
                        help='The dataset to use: cifar10 or cifar100')
    parser.add_argument('--batchsize', '-b', type=int, default=64,
                        help='Number of images in each mini-batch')
    parser.add_argument('--learnrate', '-l', type=float, default=0.05,
                        help='Learning rate for SGD')
    parser.add_argument('--epoch', '-e', type=int, default=300,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--from-scratch', '-s', dest='from_scratch', action='store_true', help='imagenet訓練済みの重みを使わずに学習する')
    parser.set_defaults(use_pretrained=False)
    args = parser.parse_args(args)

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Set up a neural network to train.
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    if args.dataset == 'cifar10':
        print('Using CIFAR10 dataset.')
        class_labels = 10
        train, test = get_cifar10()
    elif args.dataset == 'cifar100':
        print('Using CIFAR100 dataset.')
        class_labels = 100
        train, test = get_cifar100()
    else:
        raise RuntimeError('Invalid dataset choice.')

    if args.from_scratch:
        print('train from scratch')
        weight = None
    else:
        print('finetuning')
        weight = 'auto'

    resnet = ResNet(weight)
    resnet.fc6 = L.Linear(2048, class_labels)
    model = L.Classifier(resnet)
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    optimizer = chainer.optimizers.MomentumSGD(args.learnrate)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(5e-4))

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)
    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    eval_trigger = (1, 'epoch')
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu), trigger=eval_trigger)
    # tensorboard Evaluatorを同時に動かすというとりあえず対処
    trainer.extend(TensorBoardReport(args.out), trigger=eval_trigger)

    # Reduce the learning rate by half every 25 epochs.
    trainer.extend(extensions.ExponentialShift('lr', 0.5),
                   trigger=(25, 'epoch'))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot at each epoch
    trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()

main(
    [
        '-d', 'cifar10',
        '-o', 'result/cifar10-tensorboard',
        '-e', '10',
        '-s'
    ])

main(
    [
        '-d', 'cifar10',
        '-o', 'result/cifar10-tensorboard',
        '-e', '10'
    ])

結果

赤がfinetuningなし、緑がありです。

よく見るグラフ
スクリーンショット 2017-11-09 12.33.46.png

ヒストグラム。finetuningなしの場合は最初裾野が広かったのが段々シュッとまとまる。
スクリーンショット 2017-11-09 12.30.52.png

ディストリビューション。これを見てどんな感想を抱けばいいのかな。
スクリーンショット 2017-11-09 12.35.04.png

勾配が消失するケースとか発散するケースを作って見たほうがいいのかな。

感想

10エポック分のResNet50の重み全部書きだしたらめちゃtensorboard重くなったので、選別しないと死にますね。

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4