15
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Tensorflow 2.0を試してみる

Last updated at Posted at 2019-03-23

# はじめに
正式リリース前ですがTensorflow 2.0少しだけ試してみます。
ソースコード

Tensorflow 2.0で変わること

2.0で大きく変わることは

  • Eager executionがdefaultになる
  • 重複したAPIを統一(tf.layerstf.keras.layersなど)
  • Contribが一掃される

で、一番の目玉は「Eager executionがdefaultになる」だと思います。
今回はEager executionとkeras実装での学習を比較してみます。
基本的な実装は公式のチュートリアルを参考にしています。

Eager executionとは

Tensorflowはdefine and runという方式で動くライブラリでしたが
Eager executionでは、pytorchやchainerのようにdefine by runで実行されます。
計算グラフを定義しながら実行するdefine by runではモデルを柔軟に定義することができるためRNNなどの実装では好まれていますし、最近のpytorchの勢いからして今後の主流になっていきそうです。

やること

  • Tensorflow 2.0.0-alpha0 を使ってみる
  • Mnistのサンプルコードを動かす
  • Eager executionとkeras実装で学習を比較する

実行環境

  • tensorflow-datasets==1.0.1
  • tensorflow==2.0.0-alpha0

モデルの定義

trainer.py
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

モデルは簡単なCNNです。kerasのModelクラスを継承してモデルを定義することができるようになったみたいです。(kerasの少し前のバージョンからできたみたいですが全然気づかなかった)
pytorchやchainerを使ったことがある人は馴染みのある書き方で、個人的にもわかりやすいと思うのでこの書き方に慣れた方がいいと思います。

Trainer

EagerTrainer

trainer.py
class EagerTrainer(object):
    def __init__(self):
        self.model = MyModel()
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.keras.optimizers.Adam()

        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

        self.test_loss = tf.keras.metrics.Mean(name='test_loss')
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    @tf.function
    def train_step(self, image, label):
        with tf.GradientTape() as tape:
            predictions = self.model(image)  # 順伝播の計算
            loss = self.loss_object(label, predictions)  # lossの計算
        gradients = tape.gradient(loss, self.model.trainable_variables)  # 勾配の計算
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))  # パラメータの更新

        self.train_loss(loss)
        self.train_accuracy(label, predictions)

    @tf.function
    def test_step(self, image, label):
        predictions = self.model(image)
        t_loss = self.loss_object(label, predictions)

        self.test_loss(t_loss)
        self.test_accuracy(label, predictions)

    def train(self, epochs, training_data, test_data):
        template = 'Epoch {}, Loss: {:.5f}, Accuracy: {:.5f}, Test Loss: {:.5f}, Test Accuracy: {:.5f}, elapsed_time {:.5f}'

        for epoch in range(epochs):
            start = time.time()
            for image, label in tqdm(training_data):
                self.train_step(image, label)
            elapsed_time = time.time() - start

            for test_image, test_label in test_data:
                self.test_step(test_image, test_label)

            print(template.format(epoch + 1,
                                  self.train_loss.result(),
                                  self.train_accuracy.result() * 100,
                                  self.test_loss.result(),
                                  self.test_accuracy.result() * 100,
                                  elapsed_time))

EagerTrainerはEager executionのtrainerです。
コードを見るとだいたいどんな操作をしているかわかると思います。

KerasTrainer

trainer.py
class KerasTrainer(object):
    def __init__(self):
        self.model = MyModel()
        self.model.compile(optimizer=tf.keras.optimizers.Adam(),
                           loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                           metrics=['accuracy'])

    def train(self, epochs, training_data, test_data):
        self.model.fit(training_data, epochs=epochs, validation_data=test_data)

KerasTrainerEagerTrainerと全く同じ学習をするように書いています。
シンプルなモデルだと圧倒的にコードの量が少なくできるのでkerasの偉大さがわかります。

Training

main.py
import argparse

import tensorflow as tf
import tensorflow_datasets as tfds

from trainer import EagerTrainer, KerasTrainer


def convert_types(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label


def main():
    parser = argparse.ArgumentParser(description='Train Example')
    parser.add_argument('--trainer', type=str, default='eager')

    args = parser.parse_args()

    dataset, info = tfds.load('mnist', with_info=True, as_supervised=True)
    mnist_train, mnist_test = dataset['train'], dataset['test']

    mnist_train = mnist_train.map(convert_types).shuffle(10000).batch(32)
    mnist_test = mnist_test.map(convert_types).batch(32)

    if args.trainer.lower() == 'eager':
        trainer = EagerTrainer()
    else:
        trainer = KerasTrainer()
    trainer.train(epochs=5, training_data=mnist_train, test_data=mnist_test)


if __name__ == '__main__':
    main()

--trainerのオプションでどちらのTrainerを使うか選べるようにしています。
実際に実行してみるとCPU上では若干Eager executionの方が速い結果になりました。
Tensorflow 1系ではEager executionだと激おそになるという噂がありましたが、改善されているかもしれません。(GPUでもっと実用的なモデルを動かしてみないとわかりませんが)

実行結果(Eager execution)

$ python main.py 
2019-03-23 14:27:42.698962: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
0it [00:00, ?it/s]2019-03-23 14:27:43.063703: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875it [00:34, 54.75it/s]
Epoch 1, Loss: 0.13644, Accuracy: 95.91666, Test Loss: 0.06534, Test Accuracy: 97.75000, elapsed_time 34.25085
0it [00:00, ?it/s]2019-03-23 14:28:19.570435: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875it [00:32, 56.85it/s]
Epoch 2, Loss: 0.08982, Accuracy: 97.28416, Test Loss: 0.06335, Test Accuracy: 97.89000, elapsed_time 32.98253
0it [00:00, ?it/s]2019-03-23 14:28:54.483842: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875it [00:34, 75.77it/s]
Epoch 3, Loss: 0.06759, Accuracy: 97.94389, Test Loss: 0.06064, Test Accuracy: 98.03667, elapsed_time 34.02746
0it [00:00, ?it/s]2019-03-23 14:29:30.609697: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875it [00:34, 54.66it/s]
Epoch 4, Loss: 0.05404, Accuracy: 98.35125, Test Loss: 0.05985, Test Accuracy: 98.14000, elapsed_time 34.30574
0it [00:00, ?it/s]2019-03-23 14:30:06.853807: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875it [00:32, 70.69it/s]
Epoch 5, Loss: 0.04531, Accuracy: 98.61067, Test Loss: 0.05936, Test Accuracy: 98.18999, elapsed_time 32.76600

実行結果(keras)

$ python main.py --trainer keras
2019-03-23 14:35:22.782574: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
Epoch 1/5
2019-03-23 14:35:23.047235: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875/1875 [==============================] - 49s 26ms/step - loss: 0.1304 - accuracy: 0.9271 - val_loss: 0.0000e+00 - val_accuracy: 0.0000e+00
Epoch 2/5
2019-03-23 14:36:12.005254: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875/1875 [==============================] - 42s 23ms/step - loss: 0.0415 - accuracy: 0.9861 - val_loss: 0.0530 - val_accuracy: 0.9828
Epoch 3/5
2019-03-23 14:36:54.346357: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875/1875 [==============================] - 42s 22ms/step - loss: 0.0219 - accuracy: 0.9927 - val_loss: 0.0632 - val_accuracy: 0.9811
Epoch 4/5
2019-03-23 14:37:36.479987: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875/1875 [==============================] - 39s 21ms/step - loss: 0.0124 - accuracy: 0.9959 - val_loss: 0.0633 - val_accuracy: 0.9826
Epoch 5/5
2019-03-23 14:38:15.134248: W ./tensorflow/core/framework/model.h:202] Encountered a stop event that was not preceded by a start event.
1875/1875 [==============================] - 39s 21ms/step - loss: 0.0089 - accuracy: 0.9966 - val_loss: 0.0665 - val_accuracy: 0.9836
15
13
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?