2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Chainerのextensions.snapshot()で生成された複数のスナップショットから推論実行する

Last updated at Posted at 2018-02-11

Chainerにはextensions.snapshot()というスナップショットを出力する機能があるので、それを使って学習の推移を比較したい。ソースコードはここ

【追記】predictorの前にwith chainer.using_config('train', False):を設定しないと誤差伝番が実行されてしまうので注意

動作環境

  • Ubuntu 16.04.3 LTS
  • Python 3.5.2
  • chainer 3.2
  • opencv-python 3.2

実行イメージ

一番左が正解画像で、右に行くほど学習が進んだスナップショットとなる。学習が進むほど文字がシャープになっているのが確認できる。

snapshots.jpg

コード

メイン部 main()

[1]でネットワーク層を設定、[2]で入力する画像の生成、[3]でスナップショットを呼び出し学習モデルの生成と推論実行、[4]で各スナップショットから生成された出力画像を連結し、表示と保存を実行する。

predict_some_snapshot.py
import cv2
import numpy as np
import chainer
import chainer.links as L
from Lib.network import JC
import Lib.imgfunc as IMG
import Tools.func as F
from predict import getModelParam, predict, isImage, checkModelType

def main(args):
    # [1]
    snapshot_path, param = getSnapshotAndParam(args.snapshot_and_json)
    unit, ch, layer, sr, af1, af2 = getModelParam(param)
    model = L.Classifier(JC(
        n_unit=unit, n_out=ch, layer=layer, rate=sr,
        actfun_1=af1, actfun_2=af2
    ))
    # [2]
    img = getImage(args.jpeg, ch, args.random_seed)
    out_imgs = [img]
    # [3]
    for s in snapshot_path:
        load_path = checkModelType(s)
        try:
            chainer.serializers.load_npz(s, model, path=load_path)
        except:
            import traceback
            traceback.print_exc()
            print(F.fileFuncLine())
            exit()

        if args.gpu >= 0:
            chainer.cuda.get_device_from_id(args.gpu).use()
            model.to_gpu()

        with chainer.using_config('train', False):        
            out_imgs.append(predict(model, args, img, ch, -1))

    # [4]
    img = stackImages(out_imgs, args.img_rate)
    cv2.imshow('predict some snapshots', img)
    cv2.waitKey()
    cv2.imwrite(F.getFilePath(args.out_path, 'snapshots.jpg'), img)

推論実行 predict()

predict()をはじめとしたいくつかの関数は自作である。ソースコードは同じリポジトリにある。[1]で画像を圧縮して劣化させ、[2]で分割する。[3]でバッチサイズごとに実行し、[4]で分割した画像を再度結合し、[5]でサイズを入力画像と同じにして[6]で保存する。

predict.py
import cv2
import numpy as np
import chainer
import chainer.links as L
from chainer.cuda import to_cpu
from Lib.network import JC
import Lib.imgfunc as IMG
import Tools.func as F

def predict(model, args, img, ch, val):

    org_size = img.shape
    # [1]
    comp = IMG.encodeDecode([img], IMG.getCh(ch), args.quality)
    if(val >= 0):
        cv2.imwrite(
            F.getFilePath(args.out_path, 'comp-' +
                          str(val * 10).zfill(3), '.jpg'),
            comp[0]
        )

    # [2]
    comp, size = IMG.split(comp, args.img_size)
    imgs = []
    # [3]
    for i in range(0, len(comp), args.batch):
        x = IMG.imgs2arr(comp[i:i + args.batch], gpu=args.gpu)
        y = model.predictor(x)
        y = to_cpu(y.array)
        y = IMG.arr2imgs(y, ch, args.img_size * 2)
        imgs.extend(y)

    # [4]
    buf = [np.vstack(imgs[i * size[0]: (i + 1) * size[0]])
           for i in range(size[1])]
    img = np.hstack(buf)
    # [5]
    h = 0.5
    half_size = (int(img.shape[1] * h), int(img.shape[0] * h))
    flg = cv2.INTER_NEAREST
    img = cv2.resize(img, half_size, flg)
    img = img[:org_size[0], :org_size[1]]
    # [6]
    if(val >= 0):
        name = F.getFilePath(args.out_path, 'comp-' + str(val * 10 + 1).zfill(3), '.jpg')
        print('save:', name)
        cv2.imwrite(name, img)

    return img

以上。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?