1
5

More than 3 years have passed since last update.

ResNetを実装してみた!

Last updated at Posted at 2020-05-16

この記事を読者層について

DeepLearningの畳み込みニューラルネットワーク(以下CNN)の基礎知識があり、以下のような言葉の意味がわかる方
例)
- 畳み込み
- MaxPooling
- フィルター

ResNetとは

CNNの手法の1つであり、他のCNNよりも多くの層を追加する事ができる。
特徴としては、モジュールの最後に、モジュールのインプットデータをモジュール内で処理したデータに加算する(ショートカットコネクション)。
詳しくは、こちらです。

動作環境

GoogleColaboratory

サンプルプログラム

# 必要なライブラリーのインストール
import tensorflow.keras as keras
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.datasets import cifar10

# CIFAR10のデータを取得して、ベクトルに変換するクラス
class CIFAR10Dataset():
  def __init__(self):
    self.image_shape = (32, 32, 3)
    self.num_classes = 10

  # 学習データとテストデータを取得する。
  def get_batch(self):
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = [self.change_vec(img_data) for img_data in [x_train, x_test]]
    y_train, y_test = [self.change_vec(img_data, label_data=True) for img_data in [y_train, y_test]]
    return x_train, y_train, x_test, y_test

  # 目的変数の場合は、クラスベクトルに変更する。説明変数は標準化する。
  def change_vec(self, img_data, label=False):
    if label:
      data = keras.utils.to_categorical(img_data, self.num_classes)
    else:
      img_data = img_data.astype("float32")
      img_data /= 255
      shape = (img_data.shape[0],) + self.image_shape
      img_data = img_data.reshape(shape)
    return img_data

# ディープラーニングのモデルを設定して返す関数
def network(input_shape, num_classes,  count):
  filter_count = 32
  inputs = Input(shape=input_shape)
  x = Conv2D(32, kernel_size=3, padding="same", activation="relu")(inputs)
  x = BatchNormalization()(x)
  for i in range(count):
    shutcut = x #ショートカットコネクション用にモジュールの入力データを取得する。
    x = Conv2D(filter_count, kernel_size=3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(rate=0.3)(x)
    x = Conv2D(filter_count, kernel_size=3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Concatenate()([x, shutcut]) #ショートカットコネクション
    if i != count - 1:
      x = MaxPooling2D(pool_size=2)(x)
      filter_count = filter_count * 2
  x = Flatten()(x)
  x = BatchNormalization()(x)
  x = Dense(1024, activation="relu")(x)
  x = Dropout(rate=0.3)(x)
  x = BatchNormalization()(x)
  x = Dense(1024, activation="relu")(x)
  x = Dropout(rate=0.3)(x)
  x = BatchNormalization()(x)
  x = Dense(num_classes, activation="softmax")(x)
  model = Model(inputs=inputs, outputs=x)
  print(model.summary())
  return model

# モデルを学習させるクラス
class Trainer():
  # モデルをコンパイルして、学習するための設定をプライベートプロパティに設定する。
  def __init__(self, model, loss, optimizer):
    self._model = model
    self._model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=["accuracy"]
    )
    self._verbose = 1
    self._batch_size = 128
    self._epochs = 30

  # 実際の学習
  def fit(self, x_train, y_train, x_test, y_test):
    self._model.fit(
        x_train,
        y_train,
        batch_size=self._batch_size,
        epochs=self._epochs,
        verbose=self._verbose,
        validation_data=(x_test, y_test)
    )
    return self._model

dataset = CIFAR10Dataset() # データを取得するためのCIFAR10Datasetのインスタンス化
model = network(dataset.image_shape, dataset.num_classes, 4) #モデルの取得

x_train, y_train, x_test, y_test = dataset.get_batch()  # 学習データとテストデータの取得
trainer = Trainer(model, loss="categorical_crossentropy", optimizer="adam") # モデルとロス関数、最適化アルゴリズムを引数にして、Trainerのインスタンス化
model = trainer.fit(x_train, y_train, x_test, y_test) # モデルの学習

# モデルの評価
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss: ', score[0])
print('Test accuracy: ', score[1])

参考文献

直感DeepLearning
ResNetはなぜ良い性能を示すのか

1
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
5