LoginSignup
0
1

More than 3 years have passed since last update.

AWS SageMakerとTensorflow 2.0を使って独自データセットで学習するまで

Posted at

AWS SageMakerで独自データセットを使った学習のやり方についてまとまった記事がなかったので書いてみた。Google colabのようにコードをペタッと貼り付けるだけでは動かないのでいくつかの変更が必要だったりする。

動かしたいコード

[TensorFlow] AIを車両鉄に入門させてみたのコードをSageMaker上で動かしたい。データセットをダウンロードしてローカルで動くことをまず確認。

S3のセットアップ

S3とはファイル置き場みたいなもの。ここに学習データを入れておき、SageMakerから読み出す。出力ファイルもここに保存する。SageMakerを動かす前に設定しておかなければ使えない。

AWSマネジメントコンソールからS3を検索すると出てくる。
image.png

そしてバケットを作成をクリック。
image.png
適当なバケット名を入れてバケットを作成をクリックする。
image.png

こうするとSageMakerからは

s3://バケット名/

でアクセスできる。バケット名をクリックしてファイルをアップロードする手続きに入る。
image.png
まずフォルダの作成をクリックし、フォルダ名をyamanoteにする。フォルダが出来たらフォルダをクリックしてyamanoteフォルダの中に入り、imagesフォルダを作る。そしてその中に入って山手線の電車の画像をアップロードする。

ディレクトリ構造はこんな感じ。

yamanote/
|-- images/
    -- E231-yamanote/
    -- E233-chuo/
    -- E235-yamanote/

ここにはSageMakerからs3://バケット名/yamanote/でアクセスできる。

SageMakerの設定

AWSマネジメントコンソールからSageMakerを検索してクリック。
image.png
そしてノートブックインスタンスをクリックする。適当な名前を入れてノートブックインスタンスを作ると次のような画面になる。
image.png
これの開始をクリックしたらしばらくしたらInServiceになってJupyterを開くことができる。
image.png

SageMakerからTensorflowを実行するには、まず走らせたいPythonスクリプト(ここではyamanote.py)を用意し、これをSageMaker仕様に書き換える。そしてこれをJupyter notebookから実行する形をとる。まずはJupyter notebookの記述について記載する。

Jupyter notebookの記述

最初はファイルが何もない状態から始まる。Notebookを作るには右の方にあるNewのメニューからconda_tensorflow2_p36を選ぶ。
image.png
選ぶとJupyter notebookが起動するので以下のコードを書く。

yamanote.ipynb
import boto3
import json
import os
from sagemaker import Session, get_execution_role
from sagemaker.tensorflow import TensorFlow
from sagemaker.session import s3_input

sagemaker_session = Session()
sagemaker_role = get_execution_role()
SAGEMAKER_BUCKET = sagemaker_session.default_bucket()

S3_PATH_IMAGES = "s3://S3バケット名/yamanote/images/"

#Hyperparameters
hyper_param = {
    'batch-size': 20,
    'epochs': 10
}

estimator = TensorFlow(
    entry_point="yamanote.py",
    role=sagemaker_role,
    train_instance_count=1,
    train_instance_type="ml.m4.xlarge",
    framework_version="2.1.0",
    py_version='py3',
    script_mode=True,
    hyperparameters=hyper_param
)

estimator.fit({'images': s3_input(S3_PATH_IMAGES)})

SageMaker上でestimatorを作り、これにyamanote.pyを実行させる。そのときに渡すパラメータはここで指定する。

entroy_pointには走らせたいPythonスクリプトの名前を入れる。train_instance_typeには用いるインスタンス名を入れる。利用可能なインスタンス一覧はどこかで調べてください。

framework_versionはTensorflowのバージョンで今回は2.1.0を使う。2020/11/4現在2.2以上を指定するとエラーで動かない。py_versionはPythonのバージョン。Python3.6を使いたいのであればpy36とすればよい。

hyperparametersはJupyter notebookから渡す変数。今回は試しにbatch_sizeとepochsを渡すことにした。

最後にS3バケットへのパスをimagesというdictionaryで渡す。スクリプト側でimagesディレクトリへのパスとして受け取られる。

Pythonスクリプトの内容

yamanote.py
import argparse
import os
import numpy as np

import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Flatten, Dropout
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.models import Model
from tensorflow.keras.applications import VGG16
from tensorflow.keras.preprocessing.image import ImageDataGenerator

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Hyperparameterを渡す
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=32)

    # 入力データ、学習後のモデル、ログファイルなどの格納場所は
    # 環境変数で設定されているので、その値をコマンドライン引数に渡しておく
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--images-dir', type=str, default=os.environ['SM_CHANNEL_IMAGES'])
    parser.add_argument('--output-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])

    args, _ = parser.parse_known_args()

    # Hyperparameterを受け取る
    batch_size = args.batch_size
    epochs = args.epochs

    # S3のデータを受け取る
    data_dir = args.images_dir

    # パラメータの設定
    classes = ["E231-yamanote", "E233-chuo", "E235-yamanote"]
    num_classes = len(classes)
    img_width, img_height = 128, 128
    feature_dim = (img_width, img_height, 3)

    # 独自データセットから学習データを作る
    datagen = ImageDataGenerator(
        rescale=1.0 / 255,
        zoom_range=0.2,
        horizontal_flip=True,
        validation_split=0.3
    )

    # training用データ
    train_generator = datagen.flow_from_directory(
        data_dir,
        subset="training",
        target_size=(img_width, img_height),
        color_mode="rgb",
        classes=classes,
        class_mode="categorical",
        batch_size=batch_size,
        shuffle=True)

    # validation用データ
    validation_generator = datagen.flow_from_directory(
        data_dir,
        subset="validation",
        target_size=(img_width, img_height),
        color_mode="rgb",
        classes=classes,
        class_mode="categorical",
        batch_size=batch_size,
        shuffle=False)

    # 画像数を取得し、1エポックのミニバッチ数を計算
    num_train_samples = train_generator.n
    num_validation_samples = validation_generator.n
    steps_per_epoch_train = (num_train_samples-1) // batch_size + 1
    steps_per_epoch_validation  = (num_validation_samples-1) // batch_size + 1

    # VGG16のfine tuning
    vgg16 = VGG16(include_top=False, weights="imagenet", input_shape=feature_dim)
    for layer in vgg16.layers[:15]:
        layer.trainable = False

    layer_input = Input(shape=feature_dim)
    layer_vgg16 = vgg16(layer_input)
    layer_flat = Flatten()(layer_vgg16)
    layer_fc = Dense(256, activation="relu")(layer_flat)
    layer_dropout = Dropout(0.5)(layer_fc)
    layer_output = Dense(num_classes, activation="softmax")(layer_dropout)
    model = Model(layer_input, layer_output)
    model.summary()
    model.compile(loss="categorical_crossentropy",
                  optimizer=SGD(lr=1e-3, momentum=0.9),
                  metrics=["accuracy"])

    # コールバック関数の定義
    cp_cb = ModelCheckpoint(
        filepath=os.path.join(args.output_dir, "weights.hdf5"),
        monitor="val_loss",
        save_best_only=True, 
        verbose=1,
        mode="auto")

    reduce_lr_cb = ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=1,
        verbose=1)

    history = model.fit(
        train_generator,
        steps_per_epoch=steps_per_epoch_train,
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=steps_per_epoch_validation,
        callbacks=[cp_cb, reduce_lr_cb])

これで動くかもしれない。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1