I thought training in TensorFlow 2.0 was slow when written with gradient tape, but it was not

Posted at 2019-05-24

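Because I compared results by simply running the official code as-is, I concluded that gradient tape was slow, but that was not the case. The two approaches run at almost the same speed.

MNIST trained with Keras fit: https://www.tensorflow.org/alpha/tutorials/quickstart/beginner
MNIST trained with gradient tape: https://www.tensorflow.org/alpha/tutorials/quickstart/advanced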

Results

With the same model and data, the two were almost the same in both speed and accuracy.

Keras fit

Epoch 5/5
train: 60000/60000 [==============================] - 135s 2ms/sample - loss: 0.0032 - accuracy: 0.9990
test: 10000/10000 [==============================] - 6s 609us/sample - loss: 0.0890 - accuracy: 0.9823
CPU times: user 10min 35s, sys: 14.7 s, total: 10min 50s
Wall time: 11min 27s

gradient tape

Epoch 5, Loss: 0.025853322818875313, Accuracy: 99.20600128173828, Test Loss: 0.06388840824365616, Test Accuracy: 98.29166412353516
CPU times: user 22min 22s, sys: 31.7 s, total: 22min 54s
Wall time: 11min 47s
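
The "CPU times" and "Wall time" lines above have the shape of output from IPython's %%time magic (an assumption; the article does not show how the timing was taken). As a rough sketch of the same measurement with only the standard library, the helper below records wall-clock time with time.perf_counter and user/system CPU time with os.times. The names timed and busy_work are hypothetical and exist only for this illustration.

```python
import os
import time


def timed(fn, *args, **kwargs):
    """Run fn once and return (result, wall, user_cpu, sys_cpu) in seconds."""
    cpu_before = os.times()
    wall_before = time.perf_counter()
    result = fn(*args, **kwargs)
    wall = time.perf_counter() - wall_before
    cpu_after = os.times()
    user_cpu = cpu_after.user - cpu_before.user
    sys_cpu = cpu_after.system - cpu_before.system
    return result, wall, user_cpu, sys_cpu


def busy_work(n):
    # Hypothetical stand-in for a training run.
    return sum(i * i for i in range(n))


result, wall, user_cpu, sys_cpu = timed(busy_work, 200_000)
print(f"CPU times: user {user_cpu:.3f} s, sys: {sys_cpu:.3f} s, "
      f"total: {user_cpu + sys_cpu:.3f} s")
print(f"Wall time: {wall:.3f} s")
```

CPU time is summed over all threads of the process, so it can exceed wall time when work runs on several cores, as in the gradient tape figures above (user 22min 22s against a wall time of 11min 47s).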

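Why I got it wrong

  • The model and the data being trained were slightly different between the two runs.
  • After aligning the model and the data, the results became almost the same, as shown above.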
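
To make the point about aligning the model and the data concrete, here is a minimal sketch, not the code used for the measurements in this article. It assumes a small dense classifier and synthetic MNIST-shaped random data (both hypothetical, chosen to avoid a download), and the helper names make_dataset, make_model, time_fit and time_tape are made up for this example. Both training paths build their model from the same function and read the same tf.data pipeline, so only the training loop differs.

```python
import time

import numpy as np
import tensorflow as tf


def make_dataset(n=512, batch_size=32, seed=0):
    # Synthetic MNIST-shaped data (an assumption, used instead of real MNIST).
    rng = np.random.default_rng(seed)
    x = rng.random((n, 28, 28), dtype=np.float32)
    y = rng.integers(0, 10, size=n).astype(np.int64)
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)


def make_model():
    # One model definition shared by both training paths.
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model(tf.zeros((1, 28, 28)))  # create the weights eagerly
    return model


def time_fit(ds, epochs=1):
    model = make_model()
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  metrics=["accuracy"])
    start = time.perf_counter()
    model.fit(ds, epochs=epochs, verbose=0)
    return time.perf_counter() - start


def time_tape(ds, epochs=1):
    model = make_model()
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = loss_object(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    start = time.perf_counter()
    for _ in range(epochs):
        for images, labels in ds:
            train_step(images, labels)
    return time.perf_counter() - start


ds = make_dataset()
fit_seconds = time_fit(ds)
tape_seconds = time_tape(ds)
print(f"fit: {fit_seconds:.2f} s, gradient tape: {tape_seconds:.2f} s")
```

On a dataset this small the timings are dominated by start-up and tracing costs, so the sketch only shows the setup. A meaningful comparison needs the real data and several epochs.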

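Reference: the incorrect results (my original mistake)

Speed results with the official code as-is

  • Keras fit appeared to be about 25 times faster.

Keras fit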

CPU times: user 38.7 s, sys: 2.61 s, total: 41.3 s
Wall time: 30.8 s

gradient tape

CPU times: user 24min 29s, sys: 21.2 s, total: 24min 50s
Wall time: 12min 40s
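
As a quick check of the "about 25 times" figure, the wall times reported in this article can be converted to seconds and divided. The helper to_seconds below is a hypothetical name written for this illustration; it only understands the h/min/s units that appear in the timing lines.

```python
import re

UNIT_SECONDS = {"h": 3600, "min": 60, "s": 1}


def to_seconds(text):
    """Convert strings such as '30.8 s' or '12min 40s' to seconds."""
    total = 0.0
    for value, unit in re.findall(r"(\d+(?:\.\d+)?)\s*(h|min|s)", text):
        total += float(value) * UNIT_SECONDS[unit]
    return total


# Official code as-is: gradient tape wall time over Keras fit wall time.
mismatched_ratio = to_seconds("12min 40s") / to_seconds("30.8 s")
# Same model and data: gradient tape wall time over Keras fit wall time.
aligned_ratio = to_seconds("11min 47s") / to_seconds("11min 27s")

print(f"official code as-is: {mismatched_ratio:.1f}x")  # 24.7x
print(f"same model and data: {aligned_ratio:.2f}x")     # 1.03x
```

The gap of roughly 25x shrinks to about 3% once both runs use the same model and data.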

Other results

Keras fit

  • code
model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test, y_test)
  • accuracy
...
Epoch 5/5
60000/60000 [==============================] - 6s 100us/sample - loss: 0.0254 - accuracy: 0.9912
10000/10000 [==============================] - 0s 44us/sample - loss: 0.0842 - accuracy: 0.9812

gradient tape

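  • code (model, loss_object, optimizer, the metric objects, train_ds and test_ds are defined beforehand and are not shown here)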
@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_loss(loss)
  train_accuracy(labels, predictions)

@tf.function
def test_step(images, labels):
  predictions = model(images)
  t_loss = loss_object(labels, predictions)

  test_loss(t_loss)
  test_accuracy(labels, predictions)

EPOCHS = 5

for epoch in range(EPOCHS):
  for images, labels in train_ds:
    train_step(images, labels)

  for test_images, test_labels in test_ds:
    test_step(test_images, test_labels)

  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
  print(template.format(epoch+1,
                         train_loss.result(),
                         train_accuracy.result()*100,
                         test_loss.result(),
                         test_accuracy.result()*100))
  • accuracy
Epoch 5, Loss: 0.02371377870440483, Accuracy: 99.28133392333984, Test Loss: 0.07314498722553253, Test Accuracy: 98.22599792480469

Other notes

Start of epoch 0
Start of epoch 1
Start of epoch 2
Start of epoch 0
Start of epoch 1
Start of epoch 2
train 6.240314280999883
Epoch 1/3
782/782 [==============================] - 3s 4ms/step - loss: 0.3294 - sparse_categorical_accuracy: 0.9058
Epoch 2/3
782/782 [==============================] - 3s 4ms/step - loss: 0.1495 - sparse_categorical_accuracy: 0.9556
Epoch 3/3
782/782 [==============================] - 3s 3ms/step - loss: 0.1093 - sparse_categorical_accuracy: 0.9672
fit 9.801731540999754