@kingyo222

Adding an IoU Evaluation Function to Chainer (Writing a Custom Evaluator)


Update history

  • 2018/11/17 Fixed a bug that introduced an error when aggregating the results.

Introduction

Hello, I'm Tattsu, an easygoing engineer.
I also run a blog, so please take a look if you're interested.

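I wanted to add an IoU (Intersection over Union) evaluation function to Chainer, so I wrote my own Evaluator. With it, IoU appears as an extra column in the training log, as shown below.

For details, see the blog version of this article, "Adding an IoU evaluation function to Chainer (writing a custom Evaluator)".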

epoch       iou           main/loss   main/accuracy  ...
1           0             0.042866    0.993138       ...
2           0.000329707   0.0347965   0.993241       ...
3           0.00700626    0.0307309   0.993857       ...
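IoU for a binary mask is |A ∩ B| / |A ∪ B|, the number of pixels where both the label and the prediction are foreground divided by the number of pixels where either is. The evaluator below accumulates the intersection and union counts over the whole test set and divides once at the end. A minimal NumPy sketch of that arithmetic, using hypothetical helper names (`batch_counts`, `dataset_iou`, `mean_of_batch_ious`) and assuming 0/1 masks, contrasts it with averaging per-batch IoUs:

```python
import numpy as np


def batch_counts(labels, pred):
    """Return (intersection, union) pixel counts for binary (0/1) masks."""
    labels = labels.astype(bool)
    pred = pred.astype(bool)
    inter = int(np.logical_and(labels, pred).sum())
    union = int(np.logical_or(labels, pred).sum())
    return inter, union


def dataset_iou(batches):
    """Accumulate counts over all batches and divide once (dataset-level IoU)."""
    total_inter = 0
    total_union = 0
    for labels, pred in batches:
        inter, union = batch_counts(labels, pred)
        total_inter += inter
        total_union += union
    # Same convention as IouEvaluator: report 0 when the union is empty
    return total_inter / total_union if total_union else 0.0


def mean_of_batch_ious(batches):
    """Average of per-batch IoUs; in general this differs from dataset_iou."""
    ious = []
    for labels, pred in batches:
        inter, union = batch_counts(labels, pred)
        ious.append(inter / union if union else 0.0)
    return float(np.mean(ious))
```

For a batch with labels `[1, 1, 0, 0]` and predictions `[1, 0, 0, 0]` (IoU 1/2) plus a batch where both are `[1, 1, 1, 1]` (IoU 1), `dataset_iou` gives 5/6 while `mean_of_batch_ious` gives 0.75, because the second batch contains more foreground pixels.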

Usage examples

Add IoU to the log (IouEvaluator is used in place of extensions.Evaluator)

trainer.extend(IouEvaluator(test_iter, model, device=gpu_id))

Show IoU on the screen

trainer.extend(extensions.PrintReport(['epoch', 'iou', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))

Save an IoU graph to a file

trainer.extend(extensions.PlotReport(['iou'], x_key='epoch', file_name='iou.png'))
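PlotReport needs matplotlib to draw the graph. Because IoU is reported like any other value, it also ends up in the JSON log that `extensions.LogReport` writes (by default a file named `log` in the trainer's `out` directory). A small sketch, with hypothetical helper names and assuming the default `log_name`, that reads the per-epoch IoU back out of that file:

```python
import json
import os


def iou_history(entries):
    """Extract (epoch, iou) pairs from LogReport entries that contain 'iou'."""
    return [(entry['epoch'], entry['iou']) for entry in entries if 'iou' in entry]


def load_iou_history(outdir, log_name='log'):
    """Read the JSON log written by LogReport in the trainer's out directory.

    Assumes LogReport was extended with its default log_name ('log').
    """
    with open(os.path.join(outdir, log_name)) as f:
        # The log is a JSON list with one dict per logging interval
        return iou_history(json.load(f))
```

Call it as `load_iou_history(outdir)` with the same `outdir` that was passed to `training.Trainer(..., out=outdir)`.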

Source code

IouEvaluator.py (the evaluator itself)

IouEvaluator.py
import copy

import chainer
from chainer import reporter as reporter_module
from chainer.training import extensions


class IouEvaluator(extensions.Evaluator):
    """Evaluator that also reports IoU for binary (0/1) segmentation labels."""

    def evaluate(self):
        iterator = self._iterators['main']
        model = self._targets['main']
        eval_func = self.eval_func or model

        if self.eval_hook:
            self.eval_hook(self)

        # Rewind (or copy) the iterator so each evaluation starts from the beginning
        if hasattr(iterator, 'reset'):
            iterator.reset()
            it = iterator
        else:
            it = copy.copy(iterator)

        summary = reporter_module.DictSummary()

        # Intersection / union pixel counts, accumulated over the whole test set
        and_count = 0.
        or_count = 0.

        for batch in it:
            observation = {}

            with reporter_module.report_scope(observation):
                in_arrays = self.converter(batch, self.device)
                with chainer.no_backprop_mode():
                    if isinstance(in_arrays, tuple):
                        eval_func(*in_arrays)
                    elif isinstance(in_arrays, dict):
                        eval_func(**in_arrays)
                    else:
                        eval_func(in_arrays)
                    # Compare the labels with the prediction the model just made
                    ac, oc = self.iou(in_arrays)
                    and_count += ac
                    or_count += oc

            summary.add(observation)

        # Divide once at the end, so IoU covers all pixels instead of averaging batches
        iou_observation = {}
        if or_count == 0:
            iou_observation['iou'] = 0.
        else:
            iou_observation['iou'] = float(and_count) / or_count
        summary.add(iou_observation)

        return summary.compute_mean()

    def iou(self, in_arrays):
        """Return (intersection, union) counts for one batch.

        Assumes the converter returns an (x, t) tuple, the labels t are 0/1,
        and the target is an L.Classifier, which keeps its last output in model.y.
        """
        model = self._targets['main']

        _, labels = in_arrays
        # to_cpu returns NumPy arrays unchanged, so this also works on CPU (device None or -1)
        labels = chainer.cuda.to_cpu(labels)

        y = chainer.cuda.to_cpu(model.y.data)
        # Class scores (N, C, H, W) -> predicted labels (N, H, W)
        y = y.argmax(axis=1)

        # Bitwise AND / OR equal intersection / union only for 0/1 labels
        and_count = (labels & y).sum()
        or_count = (labels | y).sum()
        return and_count, or_count
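`labels & y` and `labels | y` are bitwise operations, so they give intersection and union counts only when both labels and predictions are 0 or 1. With more than two classes they stop meaning anything: for class ids 2 and 1, `2 & 1` is 0 and `2 | 1` is 3. A hypothetical per-class variant in NumPy, assuming integer class maps and `-1` as an ignore label (the default `ignore_label` of Chainer's `softmax_cross_entropy`), might look like this:

```python
import numpy as np


def per_class_iou(labels, scores, n_class, ignore_label=-1):
    """Per-class IoU from labels (N, H, W) and class scores (N, C, H, W).

    Hypothetical helper: pixels equal to ignore_label are excluded, and a class
    that appears in neither labels nor predictions gets IoU 0 (as in IouEvaluator).
    """
    pred = scores.argmax(axis=1)  # (N, C, H, W) -> (N, H, W), like y.argmax(axis=1)
    valid = labels != ignore_label
    ious = np.zeros(n_class, dtype=np.float64)
    for c in range(n_class):
        gt_c = (labels == c) & valid
        pred_c = (pred == c) & valid
        inter = np.logical_and(gt_c, pred_c).sum()
        union = np.logical_or(gt_c, pred_c).sum()
        ious[c] = inter / union if union > 0 else 0.0
    return ious
```

To use this inside an evaluator, the intersection and union would need to be accumulated per class across batches (as `IouEvaluator` does for the binary case) rather than averaging per-batch results.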

Calling code

import os

import chainer.links as L
from chainer import iterators, optimizers, serializers, training
from chainer.training import extensions

from IouEvaluator import IouEvaluator  # the class above, saved as IouEvaluator.py


def train(model_object, batchsize=64, gpu_id=0, max_epoch=20, dataset_func=None):

    # 1. Dataset: dataset_func returns (train, test) datasets of (img, label) pairs
    dataset_name = dataset_func.__name__
    train, test = dataset_func()
    print('train=' + str(len(train)) + ' test=' + str(len(test)))
    img, label = train[0]
    print(img.shape)

    # 2. Iterator (the test iterator makes a single pass without shuffling)
    train_iter = iterators.SerialIterator(train, batchsize)
    test_iter = iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)

    # 3. Model: L.Classifier keeps its last prediction in model.y, which IouEvaluator reads
    model = L.Classifier(model_object)
    if gpu_id >= 0:
        model.to_gpu(gpu_id)

    # 4. Optimizer
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    # 5. Updater
    updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)

    # 6. Trainer
    model_name = model_object.__class__.__name__
    outdir = '{}_{}'.format(model_name, dataset_name)
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    # '{.updater.epoch}' is filled in from the trainer when each snapshot is written
    epoch_npz = '{}_{}_{}.npz'.format(model_name, dataset_name, '{.updater.epoch}')
    final_npz = '{}_{}_{}.npz'.format(model_name, dataset_name, 'fin')

    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out=outdir)

    # 7. Extensions: logging, evaluation (including IoU), reports and snapshots
    trainer.extend(extensions.LogReport())
    trainer.extend(IouEvaluator(test_iter, model, device=gpu_id))
    # (ChainerCV's SemanticSegmentationEvaluator, left disabled)
    # trainer.extend(SemanticSegmentationEvaluator(test_iter, model, label_names=['negative','positive']))
    trainer.extend(extensions.PrintReport(['epoch', 'iou', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
    trainer.extend(extensions.PlotReport(['iou'], x_key='epoch', file_name='iou.png'))
    trainer.extend(extensions.snapshot_object(model, epoch_npz))
    trainer.run()
    del trainer

    # Save the final model
    serializers.save_npz(os.path.join(outdir, final_npz), model)

    return model
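`train()` only assumes that `dataset_func` has a `__name__` and returns `(train, test)` datasets whose items are `(img, label)` pairs. A hypothetical toy dataset function in that shape, handy for a quick smoke test, with images as float32 `(C, H, W)` arrays and labels as int32 `(H, W)` masks in {0, 1} (the binary labels `IouEvaluator.iou` assumes):

```python
import numpy as np


def toy_segmentation(n_train=32, n_test=8, size=16, seed=0):
    """Hypothetical dataset_func: returns (train, test) lists of (img, label)."""
    rng = np.random.RandomState(seed)

    def make(n):
        data = []
        for _ in range(n):
            img = rng.rand(1, size, size).astype(np.float32)  # (C, H, W)
            # Foreground = pixels brighter than 0.5; labels are int32 in {0, 1}
            label = (img[0] > 0.5).astype(np.int32)  # (H, W)
            data.append((img, label))
        return data

    return make(n_train), make(n_test)
```

It could then be passed as `train(my_net, gpu_id=-1, dataset_func=toy_segmentation)`, where `my_net` stands for a hypothetical chain that outputs scores of shape `(N, 2, H, W)`.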

Closing

If you like, please also take a look at my blog "初心者向けUnity情報サイト" (a Unity information site for beginners), where I've written about many other topics.
