LoginSignup
5

More than 5 years have passed since last update.

【Google Cloud Platform】初めてやるクラウドでの機械学習②

Last updated at Posted at 2017-05-02

Google Cloud Platformで機械学習をぶん回す、チュートリアルみたいになれば良いなと思います。

まずは、以下の2つの記事を読んで、一通り実行してほしいです。
【Google Cloud Platform】初めてやるクラウドでの機械学習①
【tensorflow】初めてやるニューラルネットワーク

ニューラルネットワークの記事は読まなくても大丈夫ですが、ニューラルネットワークで使ったコードに書き加えていくので、どこが変わったのか確認できると思います。

クラウドでのトレーニング

本来はローカルで実行可能か確認すべきですが、めんどくさいので省略します。

ログインから実行直前まで

login.
# 実行場所に移動
cd ~/google-cloud-ml/samples/mnist/deployable/

# ログイン
gcloud beta ml init-project

# ジョブの名前を設定します。(以下はジョブ名の一例です。)
JOB_NAME=mnist_deployable_${USER}_$(date +%Y%m%d_%H%M%S)

# プロジェクト名を取得します。
PROJECT_ID=`gcloud config list project --format "value(core.project)"`

# Cloud Storageのバケットのアドレスを設定します。(「プロジェクト名」+「-ml」)
TRAIN_BUCKET=gs://${PROJECT_ID}-ml

# トレーニングデータを出力するパスを設定します。
TRAIN_PATH=${TRAIN_BUCKET}/${JOB_NAME}

# 過去のトレーニングで出力したファイルを削除します。
gsutil rm -rf ${TRAIN_PATH}

クラウドのストレージにトレーニング用のcsvをアップしてください。
https://console.cloud.google.com/
image.png

クラウドに送信

以下のスクリプトを~/google-cloud-ml/samples/mnist/deployable/trainerに置いてください。

sazae_deploy.py
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import tensorflow.contrib.slim as slim
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')

# CSV を parse
# csvファイルがどこにあるかはちゃんと確認して指定してね
filename_queue = tf.train.string_input_producer(["gs://ストレージ名/sazae_train.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
output1, output2, output3, input1, input2 = tf.decode_csv(value, record_defaults=[[1], [1], [1], [1.0], [1.0]])
inputs = tf.stack([input1, input2])
output = tf.stack([output1,output2,output3])
# バッチ作成
inputs_batch, output_batch = tf.train.shuffle_batch([inputs, output], 4, capacity=1600, min_after_dequeue=400)
# 中間層の生成
hiddens = slim.stack(inputs_batch, slim.fully_connected, [2,4], activation_fn=tf.sigmoid, scope="hidden")
# 予測
prediction = slim.fully_connected(hiddens, 3, activation_fn=tf.nn.softmax, scope="output")
# 誤差の計算
loss = slim.losses.softmax_cross_entropy(prediction, output_batch)
# 計算した誤差から重みを更新
train_op = slim.optimize_loss(loss, slim.get_or_create_global_step(), learning_rate=0.01, optimizer='Adam')

# トレーニング関数
def run_training():
    with tf.Session() as sess:
        try:
          summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
        except AttributeError:
          summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()

    # ミニバッチ処理ループ
        try:
          sess.run(init_op)
          for i in range(2000):
            _, t_loss = sess.run([train_op, loss])
            pre, kati, te = sess.run([prediction, output_batch, inputs_batch])
            if (i+1) % 100 == 0:
              print (t_loss)
            if (i + 1) % 1000 == 0 :
                checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
                saver.save(sess, checkpoint_file, global_step=i)
        finally:
          coord.request_stop()
        coord.join(threads)

# セッション開始関数
def main(_):
  run_training()

# モジュールとして実行可能にしている
if __name__ == '__main__':
  tf.app.run()

以下、実行コマンドになります。
クラウドに送信が成功しましたら、その下の3つを実行してください。

submit.
# クラウドに送信
gcloud ml-engine jobs submit training ${JOB_NAME} --module-name=trainer.sazae --package-path=trainer --staging-bucket="${TRAIN_BUCKET}" --region=us-central1 -- --train_dir="${TRAIN_PATH}/train" --model_dir="${TRAIN_PATH}/model"

# モデルの名前を作成
MODEL_NAME=mnist_${USER}_$(date +%Y%m%d_%H%M%S)

# クラウドでモデルを作成
gcloud ml-engine models create ${MODEL_NAME} --regions=us-central1

# 予測用バージョンを作成
gcloud beta ml-engine models versions create --origin=${TRAIN_PATH}/model/saved_model --model=${MODEL_NAME} v1

これで一通り、アップまで完了しました:point_up_2:

次回予告

次はアップした学習結果を利用して予測をしてみましょう。

【Google Cloud Platform】初めてやるクラウドでの機械学習③

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5