Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

Cloud ML EngineでKerasモデルを学習した際、I/O関連で躓いた。その時の対処法。

More than 3 years have passed since last update.

Cloud ML Engine 便利です。GPUを用いて学習をしたい時は、設定ファイルを2行書けばOKです。

しかし、ローカルマシンで動作を確認しクラウド上で実行をしようとした際に、I/O関連でErrorが発生し何度か失敗してしまいました。具体的にエラーが出たのは、「エポックごとに学習済みのweightをGCSに保存しようとした」時などです。その時に対処したことを説明します。

Cloud ML Engine についてはこちらのポストなどが詳しく触れられています。

ここで行っていることは、GitHubにある以下のサンプルコード内の処理を参考にしています。

要約

Cloud ML Engineのクラウド上で学習を行う際に、ファイルの書き込み読み込みなどをしている場合は一部コードを書き直す必要があります。具体的に必要な処理は以下2つです。

  1. IO関連の組み込み関数は、from tensorflow.python.lib.io import file_ioパッケージを用いたものに書き換える
  2. Kerasのmodel.load_weights()ModelCheckpointを使っている場合は、一度クラウド上のローカルにデータを保存するという処理を挟む

詳しく見ていきます。

環境

以下の環境で実行しました。

version
CloudML runtime 1.4.0
TensorFlow 1.4.0
Keras 2.0.8
h5py 2.7.1

クラウド上でtrainigをするにあたって

私の場合は画像を用いたモデルを Keras を用いて構築しました。学習の際はローカルに保存してある画像を読み込みモデルに与え学習させ、エポックごとに途中経過のweightを保存するというよく見られる処理を行っていました。

IO関連のPython組み込み関数(例えばopen())であったりをそのまま使うことはできませんでした。なのでそれらに関する箇所を書き換える必要がありました。

1. IO関連の組み込み関数は、from tensorflow.python.lib.io import file_ioパッケージを用いたものに書き換える

解決策

以下のパッケージを代わりに使うことでGCSへのアクセスgs://を行える。

from tensorflow.python.lib.io import file_io

file_io ソース

以下にIO関連組み込み関数とfile_ioパッケージとの対応を示す。

os.listdir()        # -> file_io.list_directory()
os.mkdir(path)      # -> file_io.create_dir(path) 
os.path.exists()    # -> file_io.file_exists()
open(filename, 'r') # -> file_io.FileIO(filename, 'r')

2. Kerasのmodel.load_weights()ModelCheckpointを使っている場合は、一度クラウド上のローカルにデータを保存するという処理を挟む

h5pyとは

h5py は HDF5 フォーマットファイルを取り扱うための Python ライブラリーです。 Kerasを使っているときに、weigthを扱う場合には内部でこいつが使われています。こことか。

MLクラウド上でKerasのweightを扱おうとすると、IOErrorが発生します。原因は h5py です。サンプルコード中のコメントでは以下のように言及されています。

# Unhappy hack to work around h5py not being able to write to GCS.
# Force snapshots and saves to local filesystem, then copy them over to GCS.

解決策

h5py が原因でGCSと直接やり取りすることはできません。なので、クラウド上のローカルメモリにデータを一度キャッシュしてそのデータを h5py で扱うというステップを踏みます。

これは具体的な例をみた方が早いと思うので以下に例をしまします。

GCSのweightファイルを読み込む

以下の例ではGCSにあるmodel.h5ファイルを読み込んで、modelに読み込ませています。

with file_io.FileIO('gs://project-hoo/model.h5', 'r') as reader:
    with file_io.FileIO('model.h5', 'w+') as writer:
        writer.write(reader.read())

model.load_weights('model.h5', by_name=True)

GCSへweightファイルを保存する

以下の例ではエポックごとにweightをローカルに保存しておき、学習後にGCSに保存しています。

checkpoint = ModelCheckpoint('checkpoint.hdf5', save_weights_only=True)

# after finished ...
with file_io.FileIO('checkpoint.hdf5', 'r') as reader:
    with file_io.FileIO(os.path.join('gs://project-hoo', 'checkpoint.hdf5'), 'w+') as writer:
        writer.write(reader.read())

私は以下の関数を作成して使用しています。

def copy_to_gcs(gcs_path, file_path):
  with file_io.FileIO(file_path, 'r') as reader:
     with file_io.FileIO(os.path.join(gcs_path, file_path), 'w+') as writer:
        writer.write(reader.read())

def load_from_gcs(gcs_path, target_path):
    with file_io.FileIO(gcs_path, 'r') as reader:
        with file_io.FileIO(target_path, 'w+') as writer:
            writer.write(reader.read())

最後に

Cloud ML Engineに関連した情報が少なく、これが最適解であるかわかりません。

もしベターな方法をご存知の方がいたら教えてくださいm(_ _)m

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away