LoginSignup
16
9

More than 3 years have passed since last update.

はじめに

TensorFlowを使っていると、たくさんのライブラリや様々な実装の仕方があることがわかると思います。これはtf2が出てきても依然として状況は変わってないように思えます。自由度が高い反面、公開リポジトリのコードを読む際にも、実装の仕方が様々で読み解くのが大変になってしまっていると思います。

今回は、TensorFlowで実験コードを書く際に、個人的に良いと思う実装パターンを書いていこうと思います。結論を言うと、tf.keras.Modelのエコシステムを存分に利用して実装しよう、という主張です。

※ツッコミも大歓迎です。

tf2での実装方法

ここでは主にトレーニングループの実装に焦点をおきます。

トレーニングループの実装

  • tf.keras.Model.fit
    • 基本的に、tfに実装済みの損失関数等を利用するので自前実装は必要ない
    • コンパイル時に損失関数やオプティマイザなどを指定して実行するだけ
    • 単純な教師あり学習などでよく使われている印象
  • custom training loops
  • tensorflow estimator
    • 実験の様々な処理をカプセル化する
    • 自由度が高いが、tf.kerasと比べて複雑で可読性が悪くなる印象
    • どんどん使われなくなってきていると思うが、ネット上では見かける実装方法
    • https://www.tensorflow.org/guide/estimator

主に分けてこれらの実装がよくみられます。特に「こう実装するべき!」という主張はないため、これらが混在している印象です。
tensorflowの初学者は特に実装の方法がたくさんあって苦戦を強いられる状況になってしまっていると思います。
自分自身、TensorFlowでコードを書く時の推奨の実装パターンのようなものがあったら嬉しいなと思っていたので、今回は個人的に良いと思う実装パターンを書いていきたいと思います。

tf.keras.Modelに寄せた実装

ここが今回のメイン部分になります。まずは、tf.keras.Modelに寄せて実装すると何が嬉しいか書いていきます。

いい点

  • 便利なコールバックを簡単に使える
    • LRスケジュールとか、TensorBoardのログ記録とかも優秀
  • compileが優秀
    • model.saveする際にcompileで設定した情報も保存してくれる
      • compileしていれば学習途中のoptimizerの状態保存までしてくれる => 実験の再現性
    • TensorBoardコールバックなどでも、ここで設定したロスやメトリクスを自動で記録してくれる

逆に、これらの部分は、custom training loopsなど、他の方法で実装する際には自前で実装しなければなりません。

tf.keras.Model.fitを使わない場合

tf.keras.Modelのエコシステムが勝手にやってくれる部分を自前で実装していく必要があります。

tensorboard

例えば、tensorboard用のmetricsやログを取りたかったら、summary_writerを必要な分定義し、トレーニングループ内などでログ記録用のコードを色々と書かなければならなくなり、見栄えが悪くなります。

train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

for epoch in range(EPOCHS):
  for (x_train, y_train) in train_dataset:
    train_step(model, optimizer, x_train, y_train)
  with train_summary_writer.as_default():
    tf.summary.scalar('loss', train_loss.result(), step=epoch)
    tf.summary.scalar('accuracy', train_accuracy.result(), step=epoch)

  for (x_test, y_test) in test_dataset:
    test_step(model, x_test, y_test)
  with test_summary_writer.as_default():
    tf.summary.scalar('loss', test_loss.result(), step=epoch)
    tf.summary.scalar('accuracy', test_accuracy.result(), step=epoch)

  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
  print (template.format(epoch+1,
                         train_loss.result(), 
                         train_accuracy.result()*100,
                         test_loss.result(), 
                         test_accuracy.result()*100))

  # Reset metrics every epoch
  train_loss.reset_states()
  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()

引用元

tf.keras.Model.fitのコールバックであればその設定をするだけでOKです。

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

optimizer

また、トレーニングは途中で停止・異常終了された時に適切に復旧できるように実装されている必要もあります(トレーニングインスタンスのプリエンプトが主な理由)。
checkpointを保存している例は見ますが、optimizerの状態保存などはあまり考慮されてないケースをよく見かます。これもトレーニングの継続には大事な部分でありますが、自前で実装するのは少し面倒な部分です。

これも、tf.keras.Model.compileでコンパイルしていれば、(公式で実装済みのoptimizerであれば)model.save時にデフォルトで保存してくれるようになっています。model.saveinclude_optimizer引数で制御することも可能です。

これらのように、tf.keras.Modelに寄せて実装することで、自分で書かなければいけない部分を大幅に減らし、コードの可読性を保つことができます。
ここではひとまずこれらの例を上げましたが、他にも便利な部分や、今後追加されていく機能も多いと思います。

カスタマイズ性

色々と便利な点を書いていきましたが、カスタマイズ性の観点が肝になってくると思います。
そもそも複雑なトレーニングループを書きたいからcustom training roopsで実装しているという方が多いかと思いますが、tf.keras.Model.fitでトレーニングを実行する場合にも、実はcustom training roopsと同様にトレーニングループを書くことは可能です。

tf.keras.Model.train_stepのオーバーライド

上記公式ドキュメントにも書かれていることなので、知っている人は知ってると思います。当たり前な話ですが、fit内部で使ってる関数をオーバーライドするということです。

手順

以下のようにtf.keras.Modelのサブクラスを作成します。

class MyModel(tf.keras.Model):
    """Example in overridden `tf.keras.Model.train_step`

    Arguments:
        data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
    Returns:
        The unpacked tuple, with `None`s for `y` and `sample_weight` if they are not
    provided.
    """
    def train_step(self, data):
        # If `sample_weight` is not provided, all samples will be weighted
        # equally.
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}

参照元。一部改変。

モデルを作成する時に以下のようにカスタムのモデルクラスでラップすればOKです。

def build_model(input_shape: List, num_classes: int):
    """トレーニングに使用するモデルを作成する.

    Args:
        input_shape {List} -- 入力データのshape.
        num_classes {int} -- クラス数.
    """
    # 例
    inputs = tf.keras.Input(shape=input_shape)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs)

    model = MyModel(inputs, outputs)
    # カスタムモデルクラスをしようしない場合は以下
    #model = tf.keras.Model(inputs, outputs)
    return model

たったこれだけでtf.keras.Model.fitを使ってトレーニングの実行が可能になります。実際、custom training roopsの方法で実装した関数をほぼそのままtf.keras.Model.train_stepに移植するだけで良いと思います。

その他、損失関数やOptimizerなども当然カスタムも可能です。tf.keras.Model.compileで指定すれば、self.optimizerself.compiled_lossなどでトレーニングループ内からアクセスできます。

def custom_loss_func(y: Tensor, y_pred: Tensor) -> Tensor:
    """カスタムの損失関数を実装する.

    Args:
        y {Tensor} -- 例えば教師ラベル
        y_pred {Tensor} -- 例えばモデルの予測値

    Returns:
        Tensor -- 損失の計算結果
    """
    loss = y - y_pred
    return loss

model.compile(loss=custom_loss_func,
              optimizer=custom_optimizer,
              metrics=[custom_metrics])

tf.keras.Model.fitに寄せた実装にするもう一つの利点は、コードの共通化が可能なところです。ここまで書いてきたカスタマイズできる部分は必要に応じてカスタマイズし、残りは毎回同様のコードを使いまわせることになります(例えば以下)。

def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # tf.distribute.Strategyを使うかどうか
    if FLAGS.use_tpu:
        # Setup tpu strategy
        cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(cluster)
        tf.tpu.experimental.initialize_tpu_system(cluster)
        distribute_strategy = tf.distribute.TPUStrategy(cluster)

        with distribute_strategy.scope():
            model = build_model(FLAGS.input_shape, num_classes=FLAGS.num_classes)
            optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
    elif FLAGS.use_gpu:
            # Setup mirrored strategy
            distribute_strategy = tf.distribute.MirroredStrategy()
            with distribute_strategy.scope():
                model = build_model(FLAGS.input_shape, num_classes=FLAGS.num_classes)
                optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
    else:
        model = build_model(FLAGS.input_shape, num_classes=FLAGS.num_classes)
        optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)

    model.compile(loss=custom_loss_func,
                  optimizer=optimizer,
                  metrics=["accuracy"])
    model.summary()

    tboard_callback = tf.keras.callbacks.TensorBoard(log_dir=f"{FLAGS.job_dir}/logs", histogram_freq=1)
    callbacks = [tboard_callback]

    train_ds = get_dataset(FLAGS.dataset, FLAGS.global_batch_size, "train")
    valid_ds = get_dataset(FLAGS.dataset, FLAGS.global_batch_size, "valid")

    for epoch in range(FLAGS.epochs):
        model.fit(train_ds, validation_data=valid_ds, callbacks=callbacks, initial_epoch=epoch, epochs=epoch+1)
        model.save(f"{FLAGS.job_dir}/checkpoints/{epoch+1}", include_optimizer=True)

    model.save(f"{FLAGS.job_dir}/saved_model", include_optimizer=False)

分散学習の場合も、基本的には、ほぼ変わらない実装で機能します。

サンプルリポジトリ

雛形のコードとサンプル実装を以下のリポジトリに載せています。

(今はSimCLRの実装例しか載せてませんが)複雑な手法の実装も可能であることがわかると思います。

なお、宣伝的になってしまいますが、SimCLRのこの実装に関しては、技術書典10でmixi tech note #5の2章でも掲載予定です。より詳しい情報や、興味がある方は読んでいただけると幸いです。
=> SimCLRの実装としてはMinimalな実装で扱いやすく、分散学習にも対応しているという点で、ある程度需要があるんじゃないかというのもあってこの題材をテーマにしてます。

まとめ

このように、tf.keras.Model.fitcompileで指定できるトレーニング手法しか使えないわけではなく、かなり拡張性が確保されています。tf.keras.Modelのエコシステムを理解すれば、その恩恵を受けつつ、かなり自由度高くトレーニングループを書くことが可能です。
また、自分でコードを書く部分は最小限に抑えられるので、可読性や拡張性の観点で優れているのではないかと思っています。

公式ドキュメントでは、やり方は小さく書かれていますが、あまりこういう実装がいいという主張はなかったように思うので、ここで紹介させていただきました。

現状ではまだ様々な実装のされ方がしていて読むのが辛い状況ですが、これに限らず実装の仕方がもう少し統一されるようになってくれれば良いなぁと思っています。

16
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
9