[tensorflow] Training a model

Posted at 2018-03-19

Flow

Define the CNN model class and a batch function for the training data
⇒ define the model function (returning an EstimatorSpec)
⇒ define the Estimator (which uses the model function) ⇒ start training
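
A minimal, self-contained sketch of this flow (a dummy linear model on random data stands in for the CNN class; all names here are placeholders, and the real MNIST version follows below):

python
import numpy as np
import tensorflow as tf

# Model function: builds the graph and returns an EstimatorSpec
def model_fn(features, labels, mode):
  logits = tf.layers.dense(features, 2)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  train_op = tf.train.AdamOptimizer(1e-3).minimize(
      loss, tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

# Batch (input) function: feeds (features, labels) batches to the Estimator
def input_fn():
  x = np.random.rand(100, 4).astype(np.float32)
  y = np.random.randint(0, 2, 100)
  dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(10).repeat(5)
  return dataset.make_one_shot_iterator().get_next()

# Estimator wraps the model function; train() starts training
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn)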

Setting up training log output

python
tf.logging.set_verbosity(tf.logging.INFO)
tensors_to_log = {'train_accuracy': 'train_accuracy'}
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log, every_n_iter=10)
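
In tensors_to_log, the key is the label that appears in the log output and the value is the name of a tensor in the graph. The full code below creates that tensor with tf.identity(accuracy[1], name='train_accuracy'); the hook then prints it every 10 steps.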

Setting up checkpoints

python
my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.
    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.
)
mnist_classifier = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    params={
        'data_format': 'channels_last'
    },
    config=my_checkpointing_config)
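
With this config the Estimator writes checkpoints to model_dir on the given schedule, and a later train() call resumes from the newest one automatically. To inspect which checkpoint that is (a small usage sketch, not from the original):

python
# Returns the path of the newest checkpoint in model_dir, or None if there is none
latest = tf.train.latest_checkpoint(model_dir)
print(latest)  # e.g. './mnist_model/model.ckpt-22000'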

Code with comments

python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

class Model:

  def __init__(self, data_format):
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    else:
      assert data_format == 'channels_last'
      self._input_shape = [-1, 28, 28, 1]
    self.conv1 = tf.layers.Conv2D(
        32, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.conv2 = tf.layers.Conv2D(
        64, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.fc2 = tf.layers.Dense(10)
    self.dropout = tf.layers.Dropout(0.4)
    self.max_pool2d = tf.layers.MaxPooling2D(
        (2, 2), (2, 2), padding='same', data_format=data_format)

  def __call__(self, inputs, training):
    y = tf.reshape(inputs, self._input_shape)
    y = self.conv1(y)
    y = self.max_pool2d(y)
    y = self.conv2(y)
    y = self.max_pool2d(y)
    y = tf.layers.flatten(y)
    y = self.fc1(y)
    y = self.dropout(y, training=training)
    return self.fc2(y)

# Model function
def model_fn(features, labels, mode, params):
  # params selects the data format (NCHW = 'channels_first', NHWC = 'channels_last')
  model = Model(params['data_format'])
  image = features
  # EstimatorInstance.train() takes the `if` branch below.
  # With tf.estimator.ModeKeys.PREDICT, prediction can also be defined in this
  # same function; EstimatorInstance.predict() then predicts the labels.
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Optimizer: how the loss is driven down. GradientDescentOptimizer,
    # AdamOptimizer, etc. are available; AdamOptimizer is generally
    # considered a strong default.
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    # training=True so that dropout is applied only during training
    logits = model(image, training=True)
    # Loss definition (sigmoid, softmax, etc. are available);
    # the network is trained to minimize this loss
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    # Compute the accuracy
    accuracy = tf.metrics.accuracy(
        labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
    # Name the tensor 'train_accuracy' so the LoggingTensorHook can find it
    tf.identity(accuracy[1], name='train_accuracy')
    # Add train_accuracy to the summary (for TensorBoard)
    tf.summary.scalar('train_accuracy', accuracy[1])
    print("return model_fn ...")
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
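
  # A sketch (not in the original article) of how the same model_fn could also
  # handle EVAL and PREDICT, as the comment above suggests; both reuse the
  # model with training=False so dropout is disabled.
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={'accuracy': tf.metrics.accuracy(
            labels=tf.argmax(labels, axis=1),
            predictions=tf.argmax(logits, axis=1))})
  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions={'classes': tf.argmax(logits, axis=1),
                     'probabilities': tf.nn.softmax(logits)})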

# Extract the train split of the MNIST data
def train_dataset(data_dir):
  data = input_data.read_data_sets(data_dir, one_hot=True).train
  return tf.data.Dataset.from_tensor_slices((data.images, data.labels))

def train_input_fn():
  dataset = train_dataset(data_dir)
  # epochs = how many passes over the training data.
  # With batch size 100 and 40 epochs, the 55,000 training images give
  # 55000 / 100 * 40 = 22000 steps of training.
  dataset = dataset.shuffle(buffer_size=50000).batch(batch_size).repeat(
      train_epochs)
  # One-shot iterator over the dataset
  (images, labels) = dataset.make_one_shot_iterator().get_next()
  return (images, labels)

# Log level: INFO
tf.logging.set_verbosity(tf.logging.INFO)

# Path of the MNIST data; downloaded and unpacked from the archive if absent
data_dir = '/tmp/mnist_data'
# Save path for the model and its checkpoints
model_dir = './mnist_model'
batch_size = 100
train_epochs = 40

# Define the Estimator
mnist_classifier = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    params={
        'data_format': 'channels_last'
    })

tensors_to_log = {'train_accuracy': 'train_accuracy'}
# Log the tensor 'train_accuracy' (under the label 'train_accuracy') every 10 steps
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log, every_n_iter=10)
print("training starts ...")
# Start training with batch size 100; with the settings above it finishes
# after 22000 steps
mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
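
Once model_fn also handles EVAL (see the sketch inside model_fn above), the trained model can be scored on the MNIST test split. eval_input_fn below is an assumed helper in the same style as train_input_fn, not part of the original code:

python
def eval_input_fn():
  # Test split, batched once; no shuffle/repeat is needed for evaluation
  data = input_data.read_data_sets(data_dir, one_hot=True).test
  dataset = tf.data.Dataset.from_tensor_slices((data.images, data.labels))
  return dataset.batch(batch_size).make_one_shot_iterator().get_next()

eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print(eval_results)  # e.g. {'accuracy': ..., 'loss': ..., 'global_step': ...}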