ハッカソンで学習を回して独自モデルを作ろうとしたが、CNNに入力するデータの作り方を調べるのに時間がかかってしまったので、メモとして残しておく。

import glob
import os
import random

import numpy as np
from PIL import Image

def load_image(pattern='data/*.png', size=(32, 32), n_train=1000, n_test=100):
    """Load labelled grayscale images and split them into train/test sets.

    Every file matching ``pattern`` is opened with Pillow, converted to
    8-bit grayscale (mode ``'L'``), resized to ``size`` and stored as a
    ``float32`` array of shape ``(1, height, width)``.  Pixel values are
    not rescaled.  The label is the integer that precedes the first
    underscore in the file name (e.g. ``3_sample.png`` -> ``3``).

    Args:
        pattern: Glob pattern selecting the image files.
        size: Target ``(width, height)`` passed to ``Image.resize``.
        n_train: Number of shuffled samples placed in the training set.
        n_test: Number of the following samples placed in the test set.

    Returns:
        A ``(train, test)`` pair of lists.  Each element is an ``(x, t)``
        tuple where ``x`` is the image array and ``t`` is a 0-d ``int32``
        array holding the label.  If fewer than ``n_train + n_test`` files
        match, the lists are shorter (the test list may be empty).

    Raises:
        ValueError: If ``n_train`` or ``n_test`` is negative, or if a file
            name does not start with an integer label.
    """
    if n_train < 0 or n_test < 0:
        raise ValueError('n_train and n_test must be non-negative')

    width, height = size

    datasets = []
    for filepath in glob.glob(pattern):
        # os.path.basename copes with the platform's separator; splitting on
        # '/' fails on Windows, where glob returns backslash-separated paths.
        filename = os.path.basename(filepath)
        try:
            label = int(filename.split('_')[0])
        except ValueError as err:
            raise ValueError(
                'file name must start with "<integer label>_": %r' % filepath
            ) from err

        # The context manager closes the underlying file once the pixel data
        # has been converted, so large datasets do not leak file handles.
        with Image.open(filepath) as img:
            gray = img.convert('L').resize((width, height))

        x = np.array(gray, dtype=np.float32).reshape(1, height, width)
        t = np.array(label, dtype=np.int32)
        datasets.append((x, t))

    # Shuffle before slicing so both subsets are random samples.
    random.shuffle(datasets)
    train = datasets[:n_train]
    test = datasets[n_train:n_train + n_test]
    return train, test


def main():  # adapted from the Chainer CIFAR-10 example
    """Train a VGG classifier on the images returned by ``load_image``.

    Builds the model, optimizer, iterators and trainer, registers the
    reporting extensions, optionally resumes from a snapshot and runs the
    training loop.

    NOTE(review): ``args``, ``chainer``, ``L``, ``models``, ``training``,
    ``extensions`` and ``TestModeEvaluator`` are not defined in the visible
    part of this file; presumably they come from module-level imports and
    argument parsing as in the Chainer example -- confirm.
    """
    n_classes = 10
    train_set, test_set = load_image()

    classifier = L.Classifier(models.VGG.VGG(n_classes))
    gpu_id = args.gpu
    if gpu_id >= 0:
        # Select the requested GPU, then copy the model onto it.
        chainer.cuda.get_device_from_id(gpu_id).use()
        classifier.to_gpu()

    sgd = chainer.optimizers.MomentumSGD(args.learnrate)
    sgd.setup(classifier)
    sgd.add_hook(chainer.optimizer.WeightDecay(5e-4))

    train_batches = chainer.iterators.SerialIterator(train_set, args.batchsize)
    # The evaluation iterator makes a single, unshuffled pass.
    test_batches = chainer.iterators.SerialIterator(
        test_set, args.batchsize, repeat=False, shuffle=False)

    # Trainer setup: the run stops after ``args.epoch`` epochs.
    updater = training.StandardUpdater(train_batches, sgd, device=gpu_id)
    run_length = (args.epoch, 'epoch')
    trainer = training.Trainer(updater, run_length, out=args.out)

    # Evaluate on the test set (extension's default trigger).
    trainer.extend(TestModeEvaluator(test_batches, classifier, device=gpu_id))

    # Halve the learning rate every 25 epochs.
    trainer.extend(extensions.ExponentialShift('lr', 0.5),
                   trigger=(25, 'epoch'))

    # Dump the computational graph rooted at the 'main/loss' variable;
    # "main" is the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Save a snapshot.  The trigger interval equals the total epoch count,
    # so this fires once the run reaches its last epoch, not after every one.
    trainer.extend(extensions.snapshot(), trigger=run_length)

    # Collect the reported statistics into a log.
    trainer.extend(extensions.LogReport())

    # Print selected log entries to stdout.  "validation" is the default
    # name of the Evaluator extension; everything except 'epoch' and
    # 'elapsed_time' is reported by the Classifier link.
    report_keys = [
        'epoch', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time',
    ]
    trainer.extend(extensions.PrintReport(report_keys))

    # Show a progress bar on stdout.
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Restore the trainer state from a previously saved snapshot.
        chainer.serializers.load_npz(args.resume, trainer)

    # Start training.
    trainer.run()

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.