5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlow2.0Advent Calendar 2019

Day 5

AWS SageMakerでTensorflow2.0のトレーニングジョブを実行する

Last updated at Posted at 2019-12-04

はじめに

Tensorflow 2.0の正式版が2019/10/1にリリースされましたが、2019/12/04現在SageMakerのTensorflowのlatest versionは1.14です。

このエントリでは、SageMakerのトレーニングコンテナを無理やりTensorflow 2.0対応させてトレーニングの実行を行わせることを目的としています。

SageMaker

SageMakerはAWSの提供する機械学習用のマネージド型サービスで、トレーニングジョブをコンテナで行うことで、リソースの最適化をすることができます。

SageMakerは大きく構築、トレーニング、デプロイの3機能に分かれますが、このエントリでは構築機能のノートブックを作成し、そこからTensorflow2.0に対応したトレーニングジョブを実行させます。

モデル

対象となるモデルはTensorflow2.0のチュートリアルにある、MNISTデータセットを対象としたシンプルなCNNです。

学習処理

Tensorflow2.0チュートリアルページにあるMNIST分類のためのシンプルなCNNをSageMakerを使ってトレーニングします。

学習処理はscript.pyという名前のファイルに記述し、実装部分はチュートリアルをそのまま流用します。

本来であれば、訓練および評価用のデータセットをS3上に用意してsagemaker.Tensorflow.estimatorに読み込ませる必要があるのですが、今回はkeras.datasetsにあるMNISTのデータセットをそのまま使っています。

script.py
# https://www.tensorflow.org/tutorials/images/cnn のほぼコピー
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

if __name__ == '__main__':

    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

    train_images = train_images.reshape((60000, 28, 28, 1))
    test_images = test_images.reshape((10000, 28, 28, 1))

    train_images, test_images = train_images / 255.0, test_images / 255.0
    
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(train_images, train_labels, epochs=5)

このプログラムを使ったトレーニングジョブを実行するためのノートブックを用意します。

まずはestimator作成などに必要なライブラリを読み込み、role情報をget_execution_role()から読み込みます。

notebook
import os
import numpy as np
import pandas as pd
import sagemaker
from sagemaker import get_execution_role
from sagemaker.tensorflow import TensorFlow

sagemaker_session = sagemaker.Session()
role = get_execution_role()

次にestimator.Tensorflowを使い、トレーニングジョブに必要な情報を登録するのですが、気を付ける点が2点あり、1つめはTensorflow 2.0のトレーニングをscript modeで実行するため、そのままだとトレーニング終了後にコンテナが破棄されてしまいます。

これを防ぐため、トレーニング終了後にS3に学習済みモデルを出力する必要があります。

先ほどのscript.pyの最後にモデルの重みを保存するためのsave_weights()を追加します。

script.py
    ...
    model.fit(train_images, train_labels, epochs=args.epochs)
    model.save_weights(args.model_dir+'/model')

モデルの保存先はestimator.TensorFlowhyperparametersで指定でき、script.pyではArgumentParser()を使って引数を受け取ります。

script.py
    parser = argparse.ArgumentParser()
 
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    
    args, _ = parser.parse_known_args()

2点目はtensorflow 2.0でトレーニングのためにscript.py内でtensorflowのバージョンを2.0にアップグレードする必要があることで、そのためにはestimator.Tensorflowを通じて作成するトレーニングジョブ用コンテナのframework_versionを最新のもの(現時点では1.14.0)にする必要があります。

notebook
model_location = "s3://path/to/s3/cnn"
estimator = TensorFlow(entry_point='cnn_train.py',
                       role=role,
                       framework_version='1.14.0', # バージョンを最新に
                       hyperparameters={
                            'epochs' : 10,
                            'model-dir' : model_location, # モデルの出力先を指定
                       },
                       train_instance_count=1,
                       train_instance_type='ml.m5.xlarge',
                       script_mode=True,
                       py_version='py3')

さらにscript.py内部でpipmainを使ってtensorflow2.0に強制的にアップグレードさせます。

script.py
from pip._internal import main as pipmain
pipmain(['install','tensorflow==2.0'])

すべてまとめるとscript.pyはこのような形になります。

script.py
import os
import argparse

from pip._internal import main as pipmain
pipmain(['install','tensorflow==2.0'])

import tensorflow as tf
from tensorflow.keras import datasets, layers, models

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
 
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    
    args, _ = parser.parse_known_args()

    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

    train_images = train_images.reshape((60000, 28, 28, 1))
    test_images = test_images.reshape((10000, 28, 28, 1))

    train_images, test_images = train_images / 255.0, test_images / 255.0
    
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(train_images, train_labels, epochs=args.epochs)
    model.save_weights(args.model_dir+'/model')

ここまでできたらノートブックからestimator.fit()を実行します。

notebook
estimator.fit()

しばらくするとトレーニングジョブ用のコンテナが立ち上がります。

2019-12-04 02:23:40 Starting - Starting the training job...
2019-12-04 02:23:42 Starting - Launching requested ML instances.........
2019-12-04 02:25:15 Starting - Preparing the instances for training...
2019-12-04 02:26:05 Downloading - Downloading input data...

正常に処理が進むと、tensorflow2.0のアップグレードが始まります。

Collecting tensorflow==2.0
  Downloading https://files.pythonhosted.org/packages/46/0f/7bd55361168bb32796b360ad15a25de6966c9c1beb58a8e30c01c8279862/tensorflow-2.0.0-cp36-cp36m-manylinux2010_x86_64.whl (86.3MB)

途中でいくつかのライブラリでエラーになりますが、トレーニングジョブでawscliは使わないためそのまま処理を進めます。

ERROR: tensorboard 2.0.2 has requirement grpcio>=1.24.3, but you'll have grpcio 1.22.0 which is incompatible.
ERROR: awscli 1.16.196 has requirement botocore==1.12.186, but you'll have botocore 1.12.198 which is incompatible.
ERROR: awscli 1.16.196 has requirement PyYAML<=5.1,>=3.10; python_version != "2.6", but you'll have pyyaml 5.1.1 which is incompatible.

アップグレードが完了すると、トレーニングが始まります。

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
#015    8192/11490434 [..............................] - ETA: 1s#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#01510838016/11490434 [===========================>..] - ETA: 0s#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#010#01511493376/11490434 [==============================] - 0s 0us/step
Train on 60000 samples

トレーニングの進行具合はAWSコンソールのトレーニングジョブからも確認でき、CloudWatchで進行具合を見ることもできます。

image.png
image.png

ジョブが正常に完了するとS3にモデルが出力されます。
image.png

モデルの評価

作成したモデルをノートブックで復元します。

notebook
model.load_weights('cnn/model')
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fcbd658eeb8>

精度を確認します。

notebook
test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)
print(test_acc)
0.9906

無事学習できているようです。

おわりに

SageMakerのトレーニングジョブ内で無理やりtensorflow 2.0にアップグレードすることで、トレーニングジョブをtensorflow 2.0に対応させることができます。

しかしながら、途中でERRORが出ていることからもわかるように、この方法がうまくいく保証はどこにもないため、AWSには一日も早くtensorflow 2.0に正式対応していただきたいです。

5
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?