Resizing JPEG files with Python, then training AlexNet on them

Posted at 2019-09-13

I resized some JPEG files in order to run AlexNet, so I'm leaving a note here for future reference.

The goal is twofold: to resize JPEG image files, and then to use the resized data to classify dog and cat images with AlexNet. For this post, I used the code published here as a reference.

Importing the required libraries

import os, sys, random
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import ticker
import seaborn as sns

np.random.seed(722)

import keras
from keras.initializers import TruncatedNormal, Constant
from keras.models import Sequential
from keras.optimizers import SGD
from keras.layers import Input, Dropout, Flatten, Conv2D, MaxPooling2D, Dense, Activation, BatchNormalization
from keras.callbacks import Callback, EarlyStopping
from keras.utils.np_utils import to_categorical
from keras.datasets import mnist
from keras.datasets import cifar10
import cv2 as cv
import glob
import time
import logging
from tqdm import tqdm

1. Resizing

1.1 Setting the parameters

base_in = "/自分のディレクトリ/cache"         #リサイズしたい画像ファイルのpath
base_out = "/自分のディレクトリ/cache-resize" #リサイズ後の保存場所
resize_size = 224                          #リサイズ後のサイズを任意で設定
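
Before running anything, it can help to sanity-check that base_in points at the expected tree. A minimal sketch (it only assumes that the images live somewhere under base_in; adjust to your own layout):

import os

# Count JPEG files per subdirectory under base_in as a quick sanity check.
for dirpath, dirnames, filenames in os.walk(base_in):
    jpegs = [f for f in filenames if f.lower().endswith(('.jpg', '.jpeg'))]
    print(dirpath, '->', len(jpegs), 'jpeg files')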

1.2 Preparing to save the resized images

logger_db = logging.getLogger('CompressImage')
logger_db.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)8s - %(message)s', datefmt='%m/%d/%Y %I:%M %p')
ch.setFormatter(formatter)
logger_db.addHandler(ch)

if __name__=='__main__':

    ### Create Path List
    start_time1 = time.time()
    path_list = [ x[0] for x in os.walk(base_in)]
    logger_db.debug('Time of creating path list: {}'.format(time.time()-start_time1))

    ### Create File List
    start_time2 = time.time()
    file_list = [ os.path.join(dirpath, filename) for dirpath, _, filenames in os.walk(base_in) for filename in filenames]
    logger_db.debug('Time of creating file list: {}'.format(time.time() - start_time2))
    # logger_db.debug(file_list[:10])
    logger_db.debug('File Number is {}'.format(len(file_list)))
    logger_db.debug(file_list[:5])
    ### Create New Directory
    start_time3 = time.time()
    new_path_list  = [ x.replace(base_in, base_out) for x in path_list]
    for path in tqdm(new_path_list):
      os.makedirs(path, exist_ok=True)
    logger_db.info('Time of Creating Directory: {}'.format(time.time() - start_time3))
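
The new directory tree is derived with a plain string replace, so every subfolder under base_in gets a mirror under base_out. A toy illustration of the mapping (the paths are made up for the example):

# Illustration of the replace-based path mirroring (example paths only).
src = "/your-directory/cache/train/dog.0.jpg"
dst = src.replace("/your-directory/cache", "/your-directory/cache-resize")
print(dst)  # /your-directory/cache-resize/train/dog.0.jpg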

1.3 Resizing

    for image_file_path in tqdm(file_list):
        img = cv.imread(image_file_path)
        try:
            img = cv.resize(img, dsize=(resize_size, resize_size))
            path = image_file_path.replace(base_in, base_out)
            cv.imwrite(path, img)
        except cv.error:
            # cv.imread returns None for unreadable files, which makes
            # cv.resize raise cv.error; report the offending file.
            print("error", image_file_path)

2. Training with AlexNet

From here on, we train the model with AlexNet.

2.1 Defining the functions and the model

def conv2d(filters, kernel_size, strides=1, bias_init=1, **kwargs):
    trunc = TruncatedNormal(mean=0.0, stddev=0.01)
    cnst = Constant(value=bias_init)
    return Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding='same',
        activation='relu',
        kernel_initializer=trunc,
        bias_initializer=cnst,
        **kwargs
    )


def dense(units, **kwargs):
    trunc = TruncatedNormal(mean=0.0, stddev=0.01)
    cnst = Constant(value=1)
    return Dense(
        units,
        activation='tanh',
        kernel_initializer=trunc,
        bias_initializer=cnst,
        **kwargs
    )

def AlexNet():
    model = Sequential()

    # 1st convolutional layer
    model.add(conv2d(96, 11, strides=(4,4), bias_init=0, input_shape=(ROWS, COLS, 3)))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
    model.add(BatchNormalization())

    # 2nd convolutional layer
    model.add(conv2d(256, 5, bias_init=1))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
    model.add(BatchNormalization())

    # 3rd-5th convolutional layers
    model.add(conv2d(384, 3, bias_init=0))
    model.add(conv2d(384, 3, bias_init=1))
    model.add(conv2d(256, 3, bias_init=1))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
    model.add(BatchNormalization())

    # Fully connected layers
    model.add(Flatten())
    model.add(dense(4096))
    model.add(Dropout(0.5))
    model.add(dense(4096))
    model.add(Dropout(0.5))

    # Output layer
    model.add(Dense(2, activation='softmax'))

    model.compile(optimizer=SGD(lr=0.01), loss='categorical_crossentropy', metrics=['accuracy'])
    return model
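
Because every convolution uses padding='same' and the three max-pooling layers use 3x3 windows with stride 2, the spatial size of a 224x224 input shrinks to 6x6 before Flatten. A quick sanity check of that arithmetic:

# Feature-map size arithmetic for the model above.
size = 224
size = -(-size // 4)          # conv1: 'same' padding, stride 4 -> 56
size = (size - 3) // 2 + 1    # max-pool 3x3, stride 2          -> 27
size = (size - 3) // 2 + 1    # max-pool after conv2            -> 13
size = (size - 3) // 2 + 1    # max-pool after conv5            -> 6
print(size, 'x', size, 'x 256 =', size * size * 256)  # 6 x 6 x 256 = 9216 flattened units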

def plot_history(history):
    plt.plot(history.history['acc'],"o-",label="accuracy")
    plt.plot(history.history['val_acc'],"o-",label="val_acc")
    plt.title('model accuracy')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.show()

    plt.plot(history.history['loss'],"o-",label="loss")
    plt.plot(history.history['val_loss'],"o-",label="val_loss")
    plt.title('model loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.ylim(ymin=0)
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.show()
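
plot_history assumes the metric keys that Keras used at the time ('acc' / 'val_acc'); newer Keras/TensorFlow releases record them as 'accuracy' / 'val_accuracy' instead. If you run this on a newer version, the accuracy plot calls can select the key defensively:

# Pick whichever accuracy key this Keras version recorded.
acc_key = 'acc' if 'acc' in history.history else 'accuracy'
plt.plot(history.history[acc_key], "o-", label="accuracy")
plt.plot(history.history['val_' + acc_key], "o-", label="val_" + acc_key)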

def read(name):
    # OpenCV was imported as `cv` above, so use that alias throughout.
    return cv.imread(name, cv.IMREAD_COLOR)

def convert(img):
    return cv.resize(img, (ROWS, COLS), interpolation=cv.INTER_CUBIC)

def save(name, img):
    cv.imwrite(CACHE_DIR + name, img)
    return img

def ls(dirname):
    return [dirname + i for i in os.listdir(dirname)]

if __name__=='__main__':

    print('Main starts')
    ROWS = 224
    COLS = 224
    CHANNELS = 3

    TRAIN_DIR = 'train/'
    TEST_DIR = 'test/'
    CACHE_DIR = 'cache/'

    FORCE_CONVERT = False

2.2 Here the resized files come into play

    TRAIN_DIR = './cache-resize/train'
    TEST_DIR = './cache-resize/test'
    trainfile_list = [ os.path.join(dirpath, filename) for dirpath, _, filenames in os.walk(TRAIN_DIR) for filename in filenames]

    train  =  np.array([read(i) for i in trainfile_list])
    print("Train shape: {}".format(train.shape))

    base_in = "/自分のディレクトリ/cache-resize"  #1.1で作ったリサイズ後の保存場所を入れる

    # Create the labels
    labels = []
    for i in trainfile_list:
        if 'dog' in i:
            labels.append(0)
        else:
            labels.append(1)
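    # The check above assumes Kaggle-style file names such as dog.0.jpg /
    # cat.0.jpg, so any path that does not contain 'dog' is labeled as a cat.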

    sns.countplot(labels)
    plt.title('Dogs and Cats')
    plt.show()

    labels = to_categorical(labels)

    model = AlexNet()
    print(model.summary())
    early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, mode='auto')
    history = model.fit(train, labels, epochs=15, batch_size=128, shuffle=True, validation_split=0.25, callbacks=[early_stopping])

    plot_history(history)
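
If you also want predictions on the resized test images, the script can be continued in the same style (a sketch of mine, not part of the original; it reuses `read` and `TEST_DIR` from above):

    # Load the resized test images and run them through the trained model.
    testfile_list = [os.path.join(dirpath, filename)
                     for dirpath, _, filenames in os.walk(TEST_DIR)
                     for filename in filenames]
    test = np.array([read(i) for i in testfile_list])
    predictions = model.predict(test, batch_size=128)
    print(predictions[:5])  # softmax scores for [dog, cat]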

I hope you find this useful :)
