
Resizing JPEG files in Python and training AlexNet on them

I resized some JPEG files in order to run AlexNet, so I am noting the steps here for future reference.

The goal is to resize JPEG image files and then use the resized data to classify dog and cat images with AlexNet. For this post I used the code published here as a reference.

Importing the required libraries

import os, sys, random
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import ticker
import seaborn as sns

np.random.seed(722)

import keras
from keras.initializers import TruncatedNormal, Constant
from keras.models import Sequential
from keras.optimizers import SGD
from keras.layers import Input, Dropout, Flatten, Conv2D, MaxPooling2D, Dense, Activation, BatchNormalization
from keras.callbacks import Callback, EarlyStopping
from keras.utils.np_utils import to_categorical
from keras.datasets import mnist
from keras.datasets import cifar10
import cv2 as cv
import glob
import time
import logging
from tqdm import tqdm

1. Resizing

1.1 Setting the parameters

base_in = "/your/directory/cache"          # path to the images you want to resize
base_out = "/your/directory/cache-resize"  # where the resized images will be saved
resize_size = 224                          # output size; 224 matches AlexNet's input below
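
As a quick sanity check (my own addition, not in the original article), you can fail fast if base_in is mistyped before walking the directory tree:

if not os.path.isdir(base_in):
    raise FileNotFoundError('base_in does not exist: {}'.format(base_in))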

1.2 Preparing to save the resized images

logger_db = logging.getLogger('CompressImage')
logger_db.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)8s - %(message)s', datefmt='%m/%d/%Y %I:%M %p')
ch.setFormatter(formatter)
logger_db.addHandler(ch)

if __name__=='__main__':

    ### Create Path List
    start_time1 = time.time()
    path_list = [ x[0] for x in os.walk(base_in)]
    logger_db.debug('Time of creating path list: {}'.format(time.time()-start_time1))

    ### Create File List
    start_time2 = time.time()
    file_list = [ os.path.join(dirpath, filename) for dirpath, _, filenames in os.walk(base_in) for filename in filenames]
    logger_db.debug('Time of creating file list: {}'.format(time.time() - start_time2))
    # logger_db.debug(file_list[:10])
    logger_db.debug('File Number is {}'.format(len(file_list)))
    logger_db.debug(file_list[:5])
    ### Create New Directory
    start_time3 = time.time()
    new_path_list  = [ x.replace(base_in, base_out) for x in path_list]
    for path in tqdm(new_path_list):
        os.makedirs(path, exist_ok=True)
    logger_db.info('Time of Creating Directory: {}'.format(time.time() - start_time3))

1.3 Running the resize

    for image_file_path in tqdm(file_list):
        img = cv.imread(image_file_path)  # returns None for unreadable files
        path = image_file_path.replace(base_in, base_out)
        try:
            img = cv.resize(img, dsize=(resize_size, resize_size))
            cv.imwrite(path, img)
        except Exception:
            print("error", path)
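
To confirm the output, a minimal check (my own sketch, reusing the lists from 1.2) is to reload one resized file and inspect its shape:

    sample = cv.imread(file_list[0].replace(base_in, base_out))
    if sample is not None:
        print(sample.shape)  # expected: (224, 224, 3)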

2. Training with AlexNet

From here on, we train the model with AlexNet.

2.1 Defining the functions and the model

def conv2d(filters, kernel_size, strides=1, bias_init=1, **kwargs):
    trunc = TruncatedNormal(mean=0.0, stddev=0.01)
    cnst = Constant(value=bias_init)
    return Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding='same',
        activation='relu',
        kernel_initializer=trunc,
        bias_initializer=cnst,
        **kwargs
    )


def dense(units, **kwargs):
    trunc = TruncatedNormal(mean=0.0, stddev=0.01)
    cnst = Constant(value=1)
    return Dense(
        units,
        activation='tanh',
        kernel_initializer=trunc,
        bias_initializer=cnst,
        **kwargs
    )

def AlexNet():
    model = Sequential()

    # 1st convolutional layer
    model.add(conv2d(96, 11, strides=(4,4), bias_init=0, input_shape=(ROWS, COLS, 3)))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
    model.add(BatchNormalization())

    # 2nd convolutional layer
    model.add(conv2d(256, 5, bias_init=1))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
    model.add(BatchNormalization())

    # 3rd to 5th convolutional layers
    model.add(conv2d(384, 3, bias_init=0))
    model.add(conv2d(384, 3, bias_init=1))
    model.add(conv2d(256, 3, bias_init=1))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
    model.add(BatchNormalization())

    # Fully connected layers
    model.add(Flatten())
    model.add(dense(4096))
    model.add(Dropout(0.5))
    model.add(dense(4096))
    model.add(Dropout(0.5))

    # Readout (output) layer
    model.add(Dense(2, activation='softmax'))

    model.compile(optimizer=SGD(lr=0.01), loss='categorical_crossentropy', metrics=['accuracy'])
    return model
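
For reference, here is a quick trace (my own addition) of how the spatial size shrinks through the network: each 'same'-padded convolution gives ceil(input / stride), and each 3x3, stride-2 max pooling gives floor((input - 3) / 2) + 1.

import math

size = 224
size = math.ceil(size / 4)      # conv1 (11x11, stride 4) -> 56
size = (size - 3) // 2 + 1      # pool1                   -> 27
size = (size - 3) // 2 + 1      # pool2 (conv2 keeps 27)  -> 13
size = (size - 3) // 2 + 1      # pool3 (conv3-5 keep 13) -> 6
print(size, size * size * 256)  # 6 9216, the units entering Flatten()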

def plot_history(history):
    plt.plot(history.history['acc'],"o-",label="accuracy")
    plt.plot(history.history['val_acc'],"o-",label="val_acc")
    plt.title('model accuracy')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.show()

    plt.plot(history.history['loss'],"o-",label="loss",)
    plt.plot(history.history['val_loss'],"o-",label="val_loss")
    plt.title('model loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.ylim(ymin=0)
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.show()

def read(name):
    return cv.imread(name, cv.IMREAD_COLOR)

def convert(img):
    return cv.resize(img, (ROWS, COLS), interpolation=cv.INTER_CUBIC)

def save(name, img):
    cv.imwrite(CACHE_DIR + name, img)
    return img

def ls(dirname):
    return [dirname + i for i in os.listdir(dirname)]

if __name__=='__main__':

    print('Main starts')
    ROWS = 224
    COLS = 224
    CHANNELS = 3

    TRAIN_DIR = 'train/'
    TEST_DIR = 'test/'
    CACHE_DIR = 'cache/'

    FORCE_CONVERT = False

2.2 The resized files come into play here

    TRAIN_DIR = './cache-resize/train'
    TEST_DIR = './cache-resize/test'
    trainfile_list = [ os.path.join(dirpath, filename) for dirpath, _, filenames in os.walk(TRAIN_DIR) for filename in filenames]

    train = np.array([read(i) for i in trainfile_list])
    print("Train shape: {}".format(train.shape))

    base_in = "/your/directory/cache-resize"  # the save location created in 1.1

    # Create the labels (dog = 0, cat = 1)
    labels = []
    for i in trainfile_list:
        if 'dog' in i:
            labels.append(0)
        else:
            labels.append(1)

    sns.countplot(labels)
    plt.title('Dogs and Cats')

    labels = to_categorical(labels)

    model = AlexNet()
    print(model.summary())
    early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, mode='auto')
    history = model.fit(train, labels, epochs=15, batch_size=128, shuffle=True, validation_split=0.25, callbacks=[early_stopping])

    plot_history(history)
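
As a follow-up (my own sketch, not in the original article), you could run the trained model on a few files from TEST_DIR and print the predicted class:

    testfile_list = [os.path.join(dirpath, filename)
                     for dirpath, _, filenames in os.walk(TEST_DIR)
                     for filename in filenames]
    test = np.array([read(i) for i in testfile_list[:10]])
    preds = model.predict(test)
    for name, p in zip(testfile_list, preds):
        print(name, 'dog' if np.argmax(p) == 0 else 'cat')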

I hope this helps :)
