LoginSignup
8
20

More than 5 years have passed since last update.

CNNに入れるデータの作り方 (Chainer)

Last updated at Posted at 2017-05-22

ハッカソンで学習を回して独自モデルを作ろうとしたところ、CNNに入力するデータの作り方を調べるのに時間がかかってしまったので、メモとして残しておきます。

from PIL import Image
import numpy as np
import glob
import random

def load_image():
    """Load the PNG images under ``data/`` as a shuffled train/test split.

    Every file matching ``data/*.png`` is opened with Pillow, converted to
    grayscale, resized to 32x32 and stored as a ``float32`` array of shape
    ``(1, 32, 32)`` (channel, height, width). Pixel values are not rescaled.
    The class label is the integer found before the first underscore of the
    file name (e.g. ``3_cat.png`` -> ``3``) and is stored as a 0-d ``int32``
    array.

    Returns:
        A ``(train, test)`` tuple of lists of ``(x, t)`` pairs. After
        shuffling, the first 1000 samples form ``train`` and the next 100
        form ``test``; either list is shorter (possibly empty) when fewer
        files are available.

    Raises:
        ValueError: If a file name does not start with an integer label.
    """
    # Local import so this function is self-contained; os.path.basename
    # handles both '/' and the platform separator that glob returns on Windows.
    import os

    filepaths = glob.glob('data/*.png')

    datasets = []
    for filepath in filepaths:
        # Open via a context manager so the file handle is released promptly;
        # convert('L') returns an independent grayscale copy.
        with Image.open(filepath) as source:
            img = source.convert('L')
        img = img.resize((32, 32))  # resize to 32x32

        # The label (a non-negative integer) is the file-name prefix before
        # the first underscore.
        label = int(os.path.basename(filepath).split('_')[0])

        x = np.array(img, dtype=np.float32)
        x = x.reshape(1, 32, 32)  # (channel, height, width)
        t = np.array(label, dtype=np.int32)

        datasets.append((x, t))  # each sample is an (image, label) tuple

    random.shuffle(datasets)
    train = datasets[:1000]  # first 1000 samples for training
    test = datasets[1000:1100]  # next 100 samples for testing
    return train, test


def main():
    """Train a VGG classifier on the dataset returned by ``load_image``.

    The training setup is modelled on Chainer's CIFAR-10 example.

    NOTE(review): ``args``, ``L``, ``models``, ``chainer``, ``training``,
    ``extensions`` and ``TestModeEvaluator`` are not defined in this snippet;
    presumably they are imported / parsed as in that example -- confirm.
    """
    n_classes = 10
    train_set, test_set = load_image()

    classifier = L.Classifier(models.VGG.VGG(n_classes))
    if args.gpu >= 0:
        # Select the requested GPU and move the model's parameters onto it.
        chainer.cuda.get_device_from_id(args.gpu).use()
        classifier.to_gpu()

    # Momentum SGD with weight decay as a regularisation hook.
    sgd = chainer.optimizers.MomentumSGD(args.learnrate)
    sgd.setup(classifier)
    sgd.add_hook(chainer.optimizer.WeightDecay(5e-4))

    # The test iterator makes a single ordered pass (no repeat, no shuffle).
    train_batches = chainer.iterators.SerialIterator(train_set, args.batchsize)
    test_batches = chainer.iterators.SerialIterator(
        test_set, args.batchsize, repeat=False, shuffle=False)

    # Trainer that stops after ``args.epoch`` epochs and writes to ``args.out``.
    step_updater = training.StandardUpdater(
        train_batches, sgd, device=args.gpu)
    runner = training.Trainer(
        step_updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluation on the test set (default evaluator trigger).
    evaluator = TestModeEvaluator(test_batches, classifier, device=args.gpu)
    runner.extend(evaluator)

    # Halve the learning rate every 25 epochs.
    lr_decay = extensions.ExponentialShift('lr', 0.5)
    runner.extend(lr_decay, trigger=(25, 'epoch'))

    # Dump the computational graph rooted at 'main/loss'; "main" is the
    # target link of the "main" optimizer.
    runner.extend(extensions.dump_graph('main/loss'))

    # Snapshot the trainer, triggered every ``args.epoch`` epochs.
    runner.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))

    # Collect the reported statistics into a log.
    runner.extend(extensions.LogReport())

    # Print selected log entries to stdout. "validation" is the default name
    # of the Evaluator extension; everything except 'epoch' is reported by
    # the Classifier link when called from the updater or the evaluator.
    report_columns = [
        'epoch', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time']
    runner.extend(extensions.PrintReport(report_columns))

    # Show a progress bar on stdout.
    runner.extend(extensions.ProgressBar())

    if args.resume:
        # Restore trainer state from a previously saved snapshot.
        chainer.serializers.load_npz(args.resume, runner)

    # Start training.
    runner.run()

8
20
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
20