
I thought TensorFlow 2.0 training written with GradientTape was slow, but it isn't


Because I compared results by simply running the official code as-is, I concluded that GradientTape was slow, but that was wrong. The speeds are almost identical.

MNIST trained with Keras fit: https://www.tensorflow.org/alpha/tutorials/quickstart/beginner
MNIST trained with GradientTape: https://www.tensorflow.org/alpha/tutorials/quickstart/advanced

Results

They were almost identical.

Keras fit

Epoch 5/5
train: 60000/60000 [==============================] - 135s 2ms/sample - loss: 0.0032 - accuracy: 0.9990
test: 10000/10000 [==============================] - 6s 609us/sample - loss: 0.0890 - accuracy: 0.9823
CPU times: user 10min 35s, sys: 14.7 s, total: 10min 50s
Wall time: 11min 27s

GradientTape

Epoch 5, Loss: 0.025853322818875313, Accuracy: 99.20600128173828, Test Loss: 0.06388840824365616, Test Accuracy: 98.29166412353516
CPU times: user 22min 22s, sys: 31.7 s, total: 22min 54s
Wall time: 11min 47s

Why I got it wrong

  • The model and the data being trained were slightly different between the two tutorials.
  • After aligning the model and the data, the results were nearly identical, as shown above.
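To compare fairly, both runs need the same model and the same data. A minimal sketch of one way to line them up (the architecture and batch size here are assumptions, not necessarily the exact setup I ran):

```python
import tensorflow as tf

# Load MNIST once so both training paths see identical data.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

def build_model():
    # Rebuild the same architecture for each run so neither path
    # starts from the other's trained weights.
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])

# Path 1: Keras fit.
model_fit = build_model()
model_fit.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

# Path 2: GradientTape, fed the identical arrays through tf.data.
model_tape = build_model()
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(32)
```

With this, the only thing that differs between the two measurements is the training loop itself.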

Reference: the incorrect result (my original mistake)

Speed results from the official code as-is

  • Keras fit appeared to be roughly 25x faster.

Keras fit

CPU times: user 38.7 s, sys: 2.61 s, total: 41.3 s
Wall time: 30.8 s

GradientTape

CPU times: user 24min 29s, sys: 21.2 s, total: 24min 50s
Wall time: 12min 40s
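A likely source of that apparent 25x gap: the two quickstarts do not train the same network. The beginner tutorial fits a small dense model, while the advanced (GradientTape) tutorial trains a convolutional model with far more parameters, so each step simply costs more. A quick way to see the size difference (architectures sketched from the tutorials; details may not match exactly):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Roughly the beginner-tutorial model (trained with fit).
dense_model = tf.keras.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax'),
])

# Roughly the advanced-tutorial model (trained with GradientTape).
conv_model = tf.keras.Sequential([
    layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

# The conv model has roughly 27x more parameters than the dense one.
print('dense params:', dense_model.count_params())
print('conv params:', conv_model.count_params())
```

So a raw timing comparison of the two tutorials says more about the models than about fit vs. GradientTape.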

Other results

Keras fit

  • code
model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test, y_test)
  • Accuracy
...
Epoch 5/5
60000/60000 [==============================] - 6s 100us/sample - loss: 0.0254 - accuracy: 0.9912
10000/10000 [==============================] - 0s 44us/sample - loss: 0.0842 - accuracy: 0.9812

GradientTape

  • code
@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_loss(loss)
  train_accuracy(labels, predictions)

@tf.function
def test_step(images, labels):
  predictions = model(images)
  t_loss = loss_object(labels, predictions)

  test_loss(t_loss)
  test_accuracy(labels, predictions)

EPOCHS = 5

for epoch in range(EPOCHS):
  for images, labels in train_ds:
    train_step(images, labels)

  for test_images, test_labels in test_ds:
    test_step(test_images, test_labels)

  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
  print(template.format(epoch+1,
                         train_loss.result(),
                         train_accuracy.result()*100,
                         test_loss.result(),
                         test_accuracy.result()*100))
  • Accuracy
Epoch 5, Loss: 0.02371377870440483, Accuracy: 99.28133392333984, Test Loss: 0.07314498722553253, Test Accuracy: 98.22599792480469
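The snippet above relies on several definitions from earlier in the advanced quickstart. For reference, they are roughly the following (reconstructed from the tutorial; exact details may differ):

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

# Subclassed model, as in the advanced quickstart.
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()

# Loss, optimizer, and the metric objects updated inside
# train_step / test_step.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
```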

Other notes

Start of epoch 0
Start of epoch 1
Start of epoch 2
Start of epoch 0
Start of epoch 1
Start of epoch 2
train 6.240314280999883
Epoch 1/3
782/782 [==============================] - 3s 4ms/step - loss: 0.3294 - sparse_categorical_accuracy: 0.9058
Epoch 2/3
782/782 [==============================] - 3s 4ms/step - loss: 0.1495 - sparse_categorical_accuracy: 0.9556
Epoch 3/3
782/782 [==============================] - 3s 3ms/step - loss: 0.1093 - sparse_categorical_accuracy: 0.9672
fit 9.801731540999754
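The `train 6.24…` and `fit 9.80…` lines above are wall-clock timings of the two loops. A minimal sketch of a timing harness that produces output in that shape, with hypothetical placeholder functions standing in for the actual training runs:

```python
import time

def timed(label, fn):
    """Run fn once and print '<label> <elapsed seconds>'."""
    start = time.perf_counter()
    result = fn()
    print(label, time.perf_counter() - start)
    return result

# Hypothetical stand-ins for the real training loops.
def run_gradient_tape_loop():
    return sum(i * i for i in range(100_000))

def run_keras_fit():
    return sum(i * i for i in range(100_000))

timed('train', run_gradient_tape_loop)  # output format: "train <seconds>"
timed('fit', run_keras_fit)             # output format: "fit <seconds>"
```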