12
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

tensorflowで高速に学習するための実装

Last updated at Posted at 2018-11-18

はじめに

deep learningの学習を高速にするためには色々な側面があると思いますが、この記事で扱う内容はinput-pipelineの話です。

もともとDNNの学習は非常に時間がかかるものですが、GPUの性能が飛躍的に向上する中で学習データのI/Oやdata augmentationのための前処理部分がボトルネックとなってしまい学習が遅くなることがあります。せっかくVoltaのような高性能なGPUを使っていてもデータ生成で時間がかかってしまい、GPUをフルに使えないのは非常にもったいないです。

今回は学習データを並列に生成しキューに溜め込むことでデータ生成を高速にする実装を紹介します。TFRecordなどを使ったtensorflow独自のデータ形式を用意しなくて良いコードになっているのでtensorflowに慣れていない方もとっつきやすいと思います。

以下2つの記事の内容を参考にしています。ソースコードはgithubに載せています。
TensorFlow Data Input (Part 2): Extensions & Hacks
Building a data pipeline

実装環境

  • Python 3.6.4
  • tensorflow 1.9.0

input-pipelineの全体像

tensorflowで一番最初にやる実装方法はinputのテンソルをtf.placeholderで定義して、feed_dictでミニバッチ分の学習データを流し込んで学習していくようなコードを書くと思います。しかしこの形式で学習しているとデータの生成部分と学習や重みの更新部分を並列で動作させることができません。

この記事で紹介する実装では、学習部分ではマルチプロセスにデータを生成し入力データを貯めるキューを定義して、epochごとの検証ではtf.placeholderからデータを流し込むという実装になっています。

扱うタスクはMNISTでモデルは単純なCNNです。

重みの共有

この実装では、学習部分と検証部分で別々にモデルを定義する必要があります。学習と検証の2つのモデルを定義する時に2回目のネットワークをreuse=Trueと指定すると2つのネットワークで重みを共有することができます。

ちょっとしたはまりどころとして、tf.keras.layersでlayerを定義するとその度にインスタンスが生成され、重みを共有することができないのでreuseを使って重みを共有する時は気をつけてください。

train.py
    train_model_spec = model_fn(train_inputs, is_train=True)
    valid_model_spec = model_fn(valid_inputs, reuse=True, is_train=False)
model.py
def build_model(inputs):
    x = tf.layers.conv2d(inputs, 4, (3, 3), activation='relu', padding='same')
    x = tf.layers.batch_normalization(x)
    x = tf.layers.max_pooling2d(x, (2, 2), (2, 2))

    x = tf.layers.conv2d(x, 8, (3, 3), activation='relu', padding='same')
    x = tf.layers.batch_normalization(x)
    x = tf.layers.max_pooling2d(x, (2, 2), (2, 2))

    x = tf.layers.conv2d(x, 8, (3, 3), activation='relu', padding='same')
    x = tf.layers.batch_normalization(x)
    x = tf.layers.conv2d(x, 8, (3, 3), activation='relu', padding='same')
    x = tf.layers.batch_normalization(x)
    x = tf.layers.max_pooling2d(x, (2, 2), (2, 2))

    x = Flatten(x)

    x = tf.layers.dense(x, 64, activation='relu')
    softmax = tf.layers.dense(x, 10, activation='softmax')
    return softmax


def model_fn(inputs, reuse=False, is_train=True):
    if 'x' not in inputs:
        ValueError('x is nothing')
    if is_train and 'y' not in inputs:
        ValueError('even training mode, y is nothing')

    with tf.variable_scope('model', reuse=reuse):
        softmax = build_model(inputs['x'])
    model_spec = inputs
    model_spec['softmax'] = softmax
    if 'y' in inputs:
        cross_entropy_loss = -tf.reduce_sum(inputs['y'] * tf.log(softmax))
        if is_train:
            train_op = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy_loss)
            model_spec['train_op'] = train_op
        correct_prediction = tf.equal(tf.argmax(softmax, 1), tf.argmax(inputs['y'], 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

        tf.summary.scalar('loss', cross_entropy_loss)
        tf.summary.scalar('acc', accuracy)
        summary_op = tf.summary.merge_all()

        model_spec['loss'] = cross_entropy_loss
        model_spec['accuracy'] = accuracy
        model_spec['summary_op'] = summary_op
    return model_spec

##キューの定義
学習データを並列に生成しながら、キューに溜め込んでいくためのクラスCustomRunner を定義し、学習データを生成するiteratorを引数に与えます。tf.Session()を立ち上げた時にstart_threads関数でスレッドの数を指定し、データをキューに溜め込む準備ができます。

train.py
    mnist_generator = MnistDataGenerator(args.batch_size)
    custom_runner = CustomRunner(input_shape, num_classes, args.batch_size, mnist_generator.train_iterator)

    images, labels = custom_runner.get_inputs()
    train_inputs = {'x': images, 'y': labels}
    valid_inputs = {'x': tf.placeholder(tf.float32, [None, ] + input_shape),
                    'y': tf.placeholder(tf.float32, [None, num_classes])}

    train_model_spec = model_fn(train_inputs, is_train=True)
    valid_model_spec = model_fn(valid_inputs, reuse=True, is_train=False)

data_generator.py

class CustomRunner(object):
    """
    This class manages the the background threads needed to fill
        a queue full of data.
    """
    def __init__(self, input_shape, num_classes, batch_size, iterator):
        self.dataX = tf.placeholder(dtype=tf.float32, shape=[None, ] + input_shape)
        self.dataY = tf.placeholder(dtype=tf.float32, shape=[None, num_classes])
        self.batch_size = batch_size
        # The actual queue of data. The queue contains a vector for
        self.queue = tf.RandomShuffleQueue(shapes=[input_shape, [num_classes]],
                                           dtypes=[tf.float32, tf.float32],
                                           capacity=2000,
                                           min_after_dequeue=1000)
        # The symbolic operation to add data to the queue
        # we could do some preprocessing here or do it in numpy. In this example
        # we do the scaling in numpy
        self.iterator = iterator
        self.enqueue_op = self.queue.enqueue_many([self.dataX, self.dataY])

    def get_inputs(self):
        """
        Return's tensors containing a batch of images and labels
        """
        images_batch, labels_batch = self.queue.dequeue_many(self.batch_size)
        return images_batch, labels_batch

    def thread_main(self, sess):
        """
        Function run on alternate thread. Basically, keep adding data to the queue.
        """
        for dataX, dataY in self.iterator():
            sess.run(self.enqueue_op, feed_dict={self.dataX: dataX, self.dataY: dataY})

    def start_threads(self, sess, n_threads=1):
        """ Start background threads to feed queue """
        threads = []
        for n in range(n_threads):
            t = threading.Thread(target=self.thread_main, args=(sess,))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads

学習データ生成クラス

サンプルコードではMNISTを扱うので学習データ全てがメモリに乗りますし、data augmentationもしていないのでデータ生成部分は非常に軽いです。なので、データ生成が重い環境を再現するために各バッチを生成する際にsleepさせています。

data_generator.py
class MnistDataGenerator(object):
    def __init__(self, batch_size):
        self.batch_size = batch_size
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        self.x_train = np.expand_dims(x_train, axis=-1) / 255.
        self.x_test = np.expand_dims(x_test, axis=-1) / 255.

        self.y_train = tf.keras.utils.to_categorical(y_train, 10)
        self.y_test = tf.keras.utils.to_categorical(y_test, 10)
        self.num_train_sample = len(self.x_train)

    def train_iterator(self):
        while True:
            i = random.randint(0, len(self.x_train) - self.batch_size)
            batch_x = self.x_train[i:i + self.batch_size]
            batch_y = self.y_train[i:i + self.batch_size]

            sleep(0.05)  # データ生成で何かしらの重い処理があることを仮定
            yield batch_x, batch_y

    def test_iterator(self):
        for i in range(0, len(self.x_test), self.batch_size):
            batch_x = self.x_train[i:i + self.batch_size]
            batch_y = self.y_train[i:i + self.batch_size]
            yield batch_x, batch_y

実行

今回手元に手頃なGPUがないのでCPU上で動かしてみました。

thread=1の場合

$ python train.py --n_thread 1
Epoch 1/10
100%|█████████████████████████| 1875/1875 [01:43<00:00, 18.13it/s, train_acc=0.687, train_loss=31.8]
train/acc: 0.6868, train/loss: 31.8274
valid/acc: 0.8502, valid/loss: 14.6315
.
.
.

thread=4の場合

$ python train.py --n_thread 4
Epoch 1/10
100%|█████████████████████████| 1875/1875 [00:26<00:00, 69.64it/s, train_acc=0.668, train_loss=33.4]
train/acc: 0.6685, train/loss: 33.4167
valid/acc: 0.8849, valid/loss: 12.0425
.
.
.

thread=1では1epochあたり01:43、thread=4では1epochあたり00:26でした。マルチスレッドでデータ生成することで明らかに学習スピードが早くなることがわかります。

終わり

DNNを学習していて、最新のGPUを使っても古いGPUを使っていてもあまり学習時間が変わらないということが時々起こります。それは私の経験上データ生成部分がボトルネックとなっていることが多いです

GPUで動かしている時は特にリソースをフルに活用できないことになります。

学習がなぜか遅いなと思った時は、input-pipelineを見直してみましょう。

12
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?