
Switching Smoothly Between Eager Execution and the tf.keras Training Loop

Posted at 2019-12-17

Introduction

In TensorFlow 2.0, you can train a model in either of two ways:

  • training with Eager Execution (writing the loop yourself)
  • training with the tf.keras training loop (model.fit)

This article introduces a way of structuring your code so that you can switch smoothly between these two training methods.


The TensorFlow2.0 Advent Calendar 2019 has many related articles, so please take a look at those as well.

Do you really need both?

Honestly, you could probably just write everything with Eager Execution...

Personally, I split them up like this:

  • Eager Execution: RNNs, GANs
  • tf.keras training loop: CNNs

I mostly write complex models with Eager Execution.

Making the two interchangeable

Basically, if you stick to the latest APIs, most of the code can be shared between the two.

Write the model with the Keras subclassing API

Keras lets you define a model in three ways: the Sequential API, the Functional API, and model subclassing.

Model subclassing is the most flexible way to describe a model, so that's what we use.
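To make the contrast concrete, here is a minimal sketch of my own (a toy one-layer classifier, not code from this article's project) written in each of the three styles:

import tensorflow as tf

# 1. Sequential API: a plain stack of layers.
sequential_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation="softmax", input_shape=(784,))]
)

# 2. Functional API: wire inputs to outputs explicitly.
inputs = tf.keras.Input(shape=(784,))
outputs = tf.keras.layers.Dense(10, activation="softmax")(inputs)
functional_model = tf.keras.Model(inputs, outputs)

# 3. Model subclassing: write the forward pass yourself in call().
class SubclassedModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.fc = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, x, training=False):
        return self.fc(x)

Subclassing costs a little boilerplate but gives full control over call(), which is what makes it easy to drive the same model from both training styles.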

Write the data pipeline with tf.data

Keras models can be trained directly on a tf.data Dataset, so we use the Datasets API.
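As a minimal sketch with hypothetical in-memory tensors (the real pipeline below reads image files listed in a CSV instead), the same Dataset can be consumed by both training styles:

import tensorflow as tf

# Hypothetical toy data; the article's pipeline loads images from disk instead.
x = tf.random.normal([256, 784])
y = tf.random.uniform([256], maxval=10, dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(256).batch(32)

model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation="softmax")])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")

# tf.keras training loop: pass the Dataset straight to fit().
model.fit(ds, epochs=1)

# Eager Execution: iterate over the very same Dataset yourself.
for images, labels in ds:
    predictions = model(images)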

Other details

It's convenient to unify the remaining small details as much as possible, too (see the checkpoint sketch after this list):

  • 重みの保存形式
  • ログの表示
    • 標準出力
    • TensorBoard
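For the weight format, both tf.train.Checkpoint (used in the Eager example) and Keras's save_weights()/ModelCheckpoint(save_weights_only=True) (used in the tf.keras example) write the TensorFlow checkpoint format. A minimal sketch of the two save/restore pairs, using a throwaway model and hypothetical checkpoint directories:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.build(input_shape=(None, 784))

# Eager Execution side: save and restore via tf.train.Checkpoint.
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.save(file_prefix="checkpoint/ckpt")
checkpoint.restore(tf.train.latest_checkpoint("checkpoint")).expect_partial()

# Keras side: save_weights()/load_weights() (and the ModelCheckpoint
# callback with save_weights_only=True) write and read the same format.
model.save_weights("checkpoint-keras/ckpt")
model.load_weights(tf.train.latest_checkpoint("checkpoint-keras"))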

Concrete example (image classification)

I based the problem setup on this example.
The environment I used for testing is as follows:

  • Python 3.7
  • Anaconda 2019.10
  • TensorFlow 2.0.0

Preparing the data

Prepare the image files and a CSV file listing the labels.¹

The code below saves the MNIST data as image files and CSV files.

data_utils.py
from pathlib import Path

import pandas as pd
import tensorflow as tf
from PIL import Image


def save_data(data, image_dir, label_path):
    # Save each image as a PNG and record (file name, label) pairs in a CSV.
    images, labels = data
    image_dir.mkdir(parents=True, exist_ok=True)
    image_path_list = [image_dir / f"{i:04}.png" for i in range(len(images))]
    for image_path, image in zip(image_path_list, images):
        Image.fromarray(image).save(image_path)

    pd.DataFrame(zip([path.name for path in image_path_list], labels)).to_csv(
        label_path, header=False, index=False
    )


def read_csv(csv_path, image_dir):
    # Return full image paths and labels from a (file name, label) CSV.
    df = pd.read_csv(csv_path, header=None, names=["name", "label"])
    image_path_list = df["name"].apply(lambda x: str(image_dir / x)).to_list()
    labels = df["label"].values
    return image_path_list, labels


if __name__ == "__main__":
    train_data, test_data = tf.keras.datasets.mnist.load_data()

    data_dir = Path("input")
    save_data(train_data, data_dir / "train_images", data_dir / "train_labels.csv")
    save_data(test_data, data_dir / "test_images", data_dir / "test_labels.csv")

Shared code

Here are the model definition and the data pipeline described above.

common.py
import numpy as np
import tensorflow as tf


class Model(tf.keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation="relu")
        self.conv2 = tf.keras.layers.Conv2D(64, 3, activation="relu")
        self.pool1 = tf.keras.layers.MaxPooling2D()
        # Define every layer here, not inside call(); creating layers in
        # call() constructs a new object on each (traced) invocation.
        self.flatten = tf.keras.layers.Flatten()
        self.dp1 = tf.keras.layers.Dropout(0.25)
        self.fc1 = tf.keras.layers.Dense(128, activation="relu")
        self.dp2 = tf.keras.layers.Dropout(0.5)
        self.fc2 = tf.keras.layers.Dense(num_classes, activation="softmax")

    @tf.function
    def call(self, x, training=False):
        h = self.conv1(x)
        h = self.conv2(h)
        h = self.pool1(h)
        h = self.flatten(h)
        h = self.dp1(h, training=training)
        h = self.fc1(h)
        h = self.dp2(h, training=training)
        h = self.fc2(h)
        return h


def load_image_and_label(image_path, label):
    image = tf.io.decode_image(tf.io.read_file(image_path), channels=1)
    image = tf.reshape(image, [28, 28, 1])
    image = tf.cast(image, "float32") / 255

    label = tf.cast(label, "float32")
    return image, label


def get_data_loader(image_path_list, labels, batch_size, training):
    assert len(image_path_list) == len(labels)
    num_samples = len(image_path_list)

    ds = tf.data.Dataset.from_tensor_slices((image_path_list, labels))
    ds = ds.map(load_image_and_label, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if training:
        ds = ds.shuffle(buffer_size=num_samples)
    ds = ds.batch(batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return ds, int(np.ceil(num_samples / batch_size))

Eager Execution

Here is the Eager Execution sample. I've tried to write it so that the correspondence with the tf.keras training-loop sample below is easy to see.

from pathlib import Path

import tensorflow as tf
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from common import Model, get_data_loader
from data_utils import read_csv


def run(epoch, model, optimizer, ds, len_ds, summary_writer, loss_fn, metrics_fn, training):
    # One pass over the dataset: optimizes when training=True, otherwise just evaluates.
    prefix = "" if training else "val_"
    total_loss = 0

    with tqdm(ds, total=len_ds, disable=not training) as pbar:
        for count, (images, labels) in enumerate(pbar, start=1):
            if training:
                with tf.GradientTape() as tape:
                    predictions = model(images, training=training)
                    loss = loss_fn(labels, predictions)
                grad = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grad, model.trainable_variables))
            else:
                predictions = model(images, training=training)
                loss = loss_fn(labels, predictions)
            total_loss += loss.numpy()
            metrics_fn.update_state(labels, predictions)

            postfix_str = (
                f"{prefix}loss: {total_loss / count:.4f}, "
                + f"{prefix}accuracy: {metrics_fn.result().numpy():.4f}"
            )
            pbar.set_postfix_str(postfix_str)

    if not training:
        print(postfix_str)

    with summary_writer.as_default():
        tf.summary.scalar("epoch_loss", total_loss / len_ds, step=epoch)
        tf.summary.scalar("epoch_accuracy", metrics_fn.result().numpy(), step=epoch)
    metrics_fn.reset_states()


if __name__ == "__main__":
    num_classes = 10
    batch_size = 128
    epochs = 12
    data_dir = Path("input")
    checkpoint_dir = Path("checkpoint")
    tensorboard_dir = Path("tensorboard")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    image_path_list, labels = read_csv(
        data_dir / "train_labels.csv", data_dir / "train_images"
    )
    (
        train_image_path_list,
        valid_image_path_list,
        train_labels,
        valid_labels,
    ) = train_test_split(image_path_list, labels, test_size=0.2, random_state=42)

    train_ds, len_train_ds = get_data_loader(
        train_image_path_list, train_labels, batch_size, True
    )
    valid_ds, len_valid_ds = get_data_loader(
        valid_image_path_list, valid_labels, batch_size, False
    )

    model = Model(num_classes)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adadelta()
    metrics_fn = tf.keras.metrics.SparseCategoricalAccuracy()

    train_summary_writer = tf.summary.create_file_writer(str(tensorboard_dir / "train"))
    valid_summary_writer = tf.summary.create_file_writer(
        str(tensorboard_dir / "validation")
    )

    checkpoint_prefix = checkpoint_dir / "ckpt"
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

    # Training
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        run(
            epoch,
            model,
            optimizer,
            train_ds,
            len_train_ds,
            train_summary_writer,
            loss_fn,
            metrics_fn,
            True,
        )
        run(
            epoch,
            model,
            optimizer,
            valid_ds,
            len_valid_ds,
            valid_summary_writer,
            loss_fn,
            metrics_fn,
            False,
        )
        checkpoint.save(file_prefix=checkpoint_prefix)

    # Evaluation
    model = Model(num_classes)
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()

    test_image_path_list, test_labels = read_csv(
        data_dir / "test_labels.csv", data_dir / "test_images"
    )
    test_ds, len_test_ds = get_data_loader(
        test_image_path_list, test_labels, batch_size, False
    )
    predictions = model.predict(test_ds).argmax(axis=-1)

The terminal output looks something like this:

Epoch 12/12
100%|████████████████████████████████████████████████| 375/375 [00:03<00:00, 111.13it/s, loss: 1.1543, accuracy: 0.6771]
val_loss: 0.9472, val_accuracy: 0.7996

tf.keras Training Loop

Here is the tf.keras training-loop sample. I've tried to write it so that the correspondence with the Eager Execution sample above is easy to see.

from pathlib import Path

import tensorflow as tf
from sklearn.model_selection import train_test_split

from common import Model, get_data_loader
from data_utils import read_csv

if __name__ == "__main__":
    num_classes = 10
    batch_size = 128
    epochs = 12
    data_dir = Path("input")
    checkpoint_dir = Path("checkpoint-keras")
    tensorboard_dir = Path("tensorboard-keras")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    image_path_list, labels = read_csv(
        data_dir / "train_labels.csv", data_dir / "train_images"
    )
    (
        train_image_path_list,
        valid_image_path_list,
        train_labels,
        valid_labels,
    ) = train_test_split(image_path_list, labels, test_size=0.2, random_state=42)

    train_ds, len_train_ds = get_data_loader(
        train_image_path_list, train_labels, batch_size, True
    )
    valid_ds, len_valid_ds = get_data_loader(
        valid_image_path_list, valid_labels, batch_size, False
    )

    model = Model(num_classes)
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(name="loss"),
        optimizer=tf.keras.optimizers.Adadelta(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
    )
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=str(checkpoint_dir / "ckpt-{epoch}"), save_weights_only=True
        ),
        tf.keras.callbacks.TensorBoard(log_dir=str(tensorboard_dir), profile_batch=0),
    ]

    # Training
    model.fit(train_ds, epochs=epochs, callbacks=callbacks, validation_data=valid_ds)

    # Evaluation
    model = Model(num_classes)
    model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

    test_image_path_list, test_labels = read_csv(
        data_dir / "test_labels.csv", data_dir / "test_images"
    )
    test_ds, len_test_ds = get_data_loader(
        test_image_path_list, test_labels, batch_size, False
    )
    predictions = model.predict(test_ds).argmax(axis=-1)

The terminal output looks something like this:

Epoch 12/12
370/375 [============================>.] - ETA: 0s - loss: 1.1553 - accuracy: 0.6670
2019-12-17 22:52:11.147062: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
         [[{{node IteratorGetNext}}]]
         [[IteratorGetNext/_4]]
2019-12-17 22:52:11.147216: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
         [[{{node IteratorGetNext}}]]
375/375 [==============================] - 4s 10ms/step - loss: 1.1548 - accuracy: 0.6671 - val_loss: 0.9395 - val_accuracy: 0.8041

A warning appears, but it seems to have been fixed in TensorFlow 2.1.0-rc1.

Closing

If you know of any best practices, please let me know :bow:

  1. TensorFlow Datasets would be more convenient, but here I go through the extra trouble with real-world applications in mind.
