
Training with multiple GPUs in TensorFlow

More than 1 year has passed since last update.

Introduction

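Surprisingly, I couldn't find a Japanese-language article on this topic, so I wrote one.
The source code is on GitHub, in the same repository as my previous article.
The implementation is based on TensorFlow's CIFAR-10 example.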

Environment

  • Python 3.6.4
  • tensorflow 1.9.0

How to train on multiple GPUs

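To use multiple GPUs in TensorFlow, you have to build a copy of the model (a "tower") on each GPU, with all copies sharing the same variables.
You then write code that collects the gradients each GPU computes from its own mini-batch, averages them, and applies the result as a single update.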
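Why averaging the per-GPU gradients is the right thing to do: if every GPU receives a mini-batch of the same size and the loss is a mean over examples, the mean of the per-GPU gradients equals the gradient over the combined batch. The small pure-Python check below uses a one-parameter linear model with mean squared error; the model, data and parameter value are made up for illustration and are not from the article's code.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to w."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 5.0, 9.0]
w = 1.5

# Gradient over the whole batch, as a single GPU would compute it.
full_batch_grad = grad_mse(w, xs, ys)

# Split the batch evenly over two hypothetical GPUs ("towers"),
# compute each tower's gradient, then average them.
tower_grads = [grad_mse(w, xs[:2], ys[:2]), grad_mse(w, xs[2:], ys[2:])]
averaged_grad = sum(tower_grads) / len(tower_grads)

print(full_batch_grad, averaged_grad)  # -8.0 -8.0
```

If the towers received batches of different sizes, a plain mean would weight examples unequally, and a size-weighted average would be needed instead.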

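    import tensorflow as tf

    # args, optimizer, custom_runner, model_fn_multigpu and average_gradients
    # are defined elsewhere in the training script.
    gradients = []
    tower_losses = []
    tower_accuracies = []
    for gpu_index in range(args.num_gpus):
        # Build one replica ("tower") of the model on each GPU.
        with tf.device('/gpu:%d' % gpu_index):
            with tf.name_scope('%s_%d' % ("gpu", gpu_index)) as scope:
                # Each tower pulls its own mini-batch from the input pipeline.
                images, labels = custom_runner.get_inputs()
                train_inputs = {'x': images, 'y': labels}

                # reuse=True makes all towers share one set of variables
                # (this assumes the variables have already been created).
                train_model_spec = model_fn_multigpu(train_inputs, reuse=True, is_train=True)

                tf.add_to_collection(tf.GraphKeys.LOSSES, train_model_spec['loss'])
                tf.add_to_collection(tf.GraphKeys.METRIC_VARIABLES, train_model_spec['accuracy'])
                # Filtering by scope keeps only this tower's entries.
                losses = tf.get_collection(tf.GraphKeys.LOSSES, scope)
                accuracy = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES, scope)

                total_clone_loss = tf.add_n(losses)
                total_clone_accuracy = tf.add_n(accuracy)
                tower_losses.append(total_clone_loss)
                tower_accuracies.append(total_clone_accuracy)

                # compute clone gradients
                clone_gradients = optimizer.compute_gradients(total_clone_loss)
                gradients.append(clone_gradients)

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grads = average_gradients(gradients)

    # Apply the gradients to adjust the shared variables.
    apply_gradient_op = optimizer.apply_gradients(grads)

    # Group all updates into a single train op.
    train_op = tf.group(apply_gradient_op)
    # Report loss and accuracy averaged over all towers, not just the last one.
    train_model_spec = {'train_op': train_op,
                        'loss': tf.reduce_mean(tf.stack(tower_losses)),
                        'accuracy': tf.reduce_mean(tf.stack(tower_accuracies)),
                        }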
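The code above calls average_gradients, which the article does not show. A minimal sketch of what it has to do, assuming dense gradients: group the (gradient, variable) pairs that compute_gradients returned for the same variable on every tower and average the gradients. It is written in plain Python so it runs without TensorFlow; because TensorFlow tensors overload + and /, the same function should also build the averaging ops in a graph (TensorFlow's CIFAR-10 example does this explicitly by stacking the gradients and calling tf.reduce_mean). Sparse IndexedSlices gradients are not handled.

```python
def average_gradients(tower_grads):
    """Average per-tower gradients, variable by variable.

    tower_grads: one list per tower, each holding the (gradient, variable)
    pairs from optimizer.compute_gradients(). All towers must list the same
    variables in the same order, which holds when the towers share variables.
    Returns one list of (mean_gradient, variable) pairs.
    """
    averaged = []
    # zip(*tower_grads) groups the i-th (grad, var) pair of every tower.
    for grad_and_vars in zip(*tower_grads):
        var = grad_and_vars[0][1]
        if any(v is not var for _, v in grad_and_vars):
            raise ValueError("towers do not share the same variables")
        grads = [g for g, _ in grad_and_vars]
        if any(g is None for g in grads):
            # The loss does not depend on this variable; keep None so that
            # apply_gradients skips it.
            averaged.append((None, var))
            continue
        averaged.append((sum(grads) / len(grads), var))
    return averaged


# Usage with plain numbers standing in for gradient tensors.
w, b = "w", "b"
print(average_gradients([[(1.0, w), (4.0, b)],
                         [(3.0, w), (8.0, b)]]))  # [(2.0, 'w'), (6.0, 'b')]
```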

Running the script

Pass the --num_gpus option to train_multi_gpu.py to set the number of GPUs.
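The article does not show how train_multi_gpu.py reads --num_gpus. A minimal argparse sketch, assuming an integer option that defaults to 1; the helper tower_devices is hypothetical and simply produces the '/gpu:%d' device strings that the training loop passes to tf.device.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Multi-GPU training (sketch)")
    # Number of towers to build; the default of 1 is an assumption.
    parser.add_argument("--num_gpus", type=int, default=1,
                        help="number of GPUs to split each training step across")
    args = parser.parse_args(argv)
    if args.num_gpus < 1:
        parser.error("--num_gpus must be at least 1")
    return args


def tower_devices(num_gpus):
    """Device strings in the '/gpu:%d' form used by the training loop."""
    return ['/gpu:%d' % i for i in range(num_gpus)]


args = parse_args(["--num_gpus", "2"])
print(tower_devices(args.num_gpus))  # ['/gpu:0', '/gpu:1']
```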

Example run

$ python train_multi_gpu.py --num_gpus 2
Epoch 1/10
100%|██████████| 937/937 [00:07<00:00, 118.05it/s, train_acc=0.871, train_loss=13.8]
train/acc: 0.8705, train/loss: 13.7522
valid/acc: 0.9402, valid/loss: 6.2843
Epoch 2/10
100%|██████████| 937/937 [00:06<00:00, 148.26it/s, train_acc=0.961, train_loss=4.02]
train/acc: 0.9610, train/loss: 4.0176
valid/acc: 0.9680, valid/loss: 3.3793
.
.
.

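Running nvidia-smi during training confirms that both GPUs are in use. Since each step consumes one mini-batch per GPU, an epoch takes 937 steps here instead of the 1875 steps of the single-GPU run below, and it finishes in about 7 seconds instead of about 24.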

$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla M60           On   | 00000000:00:1D.0 Off |                    0 |
| N/A   28C    P0    39W / 150W |   7365MiB /  7618MiB |     34%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla M60           On   | 00000000:00:1E.0 Off |                    0 |
| N/A   32C    P0    38W / 150W |   7365MiB /  7618MiB |     27%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     22375      C   python                                      7354MiB |
|    1     22375      C   python                                      7354MiB |
+-----------------------------------------------------------------------------+
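Instead of reading the table by eye, nvidia-smi can print just the fields of interest as CSV (nvidia-smi --query-gpu=index,utilization.gpu,memory.used --format=csv,noheader,nounits), which makes the check scriptable. A minimal sketch of the parsing side; the sample strings are assembled from the utilization and memory figures in this article's nvidia-smi output, not captured from a real run.

```python
import csv
import io


def parse_gpu_usage(csv_text):
    """Parse the output of
    nvidia-smi --query-gpu=index,utilization.gpu,memory.used --format=csv,noheader,nounits
    into {gpu_index: (utilization_percent, memory_used_mib)}.
    """
    usage = {}
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:  # skip blank lines
            continue
        index, util, mem = (field.strip() for field in row)
        usage[int(index)] = (int(util), int(mem))
    return usage


def busy_gpus(usage, min_util_percent=1):
    """Indices of GPUs whose utilization is at least min_util_percent."""
    return sorted(i for i, (util, _) in usage.items() if util >= min_util_percent)


multi_gpu_run = "0, 34, 7365\n1, 27, 7365\n"
single_gpu_run = "0, 25, 7365\n1, 0, 7245\n"
print(busy_gpus(parse_gpu_usage(multi_gpu_run)))   # [0, 1]
print(busy_gpus(parse_gpu_usage(single_gpu_run)))  # [0]
```

Utilization is a momentary sample, so in practice it is safer to query a few times during training than to trust a single reading.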

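By contrast, running the single-GPU script train.py without specifying any GPUs uses only one GPU: GPU 1 stays at 0% utilization, although TensorFlow still reserves memory on it, because by default it claims memory on every visible GPU.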

$ python train.py
Epoch 1/10
100%|██████████| 1875/1875 [00:24<00:00, 76.53it/s, train_acc=0.673, train_loss=33.3]
train/acc: 0.6735, train/loss: 33.3285
valid/acc: 0.8651, valid/loss: 13.2471
Epoch 2/10
100%|██████████| 1875/1875 [00:23<00:00, 80.22it/s, train_acc=0.898, train_loss=10.1]
train/acc: 0.8980, train/loss: 10.1478
valid/acc: 0.9133, valid/loss: 8.7852
.
.
.

$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla M60           On   | 00000000:00:1D.0 Off |                    0 |
| N/A   28C    P0    40W / 150W |   7365MiB /  7618MiB |     25%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla M60           On   | 00000000:00:1E.0 Off |                    0 |
| N/A   32C    P0    36W / 150W |   7245MiB /  7618MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     22698      C   python                                      7354MiB |
|    1     22698      C   python                                      7234MiB |
+-----------------------------------------------------------------------------+
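In the single-GPU run, GPU 1 sits at 0% utilization yet still has about 7 GB allocated. One common way to keep a single-GPU job off the other GPU is to hide it with the CUDA_VISIBLE_DEVICES environment variable before TensorFlow initializes CUDA (TensorFlow 1.x can also be told to allocate memory on demand via tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)), but that alone does not hide the GPU). A minimal sketch; restrict_to_gpus is a hypothetical helper, not part of the article's scripts.

```python
import os


def restrict_to_gpus(gpu_ids):
    """Expose only the listed GPUs to CUDA-based libraries such as TensorFlow.

    Call this before TensorFlow is imported; once CUDA has been initialized,
    changing the variable has no effect. Inside the process, the visible
    GPUs are renumbered starting from 0.
    """
    value = ",".join(str(i) for i in gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


if __name__ == "__main__":
    # Hypothetical usage at the top of a single-GPU script such as train.py.
    restrict_to_gpus([0])
    # import tensorflow as tf  # import only after the variable is set
```

The same effect is available from the shell without touching the script, e.g. CUDA_VISIBLE_DEVICES=0 python train.py.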

Conclusion

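Keras makes multi-GPU training easy to set up, but in TensorFlow you have to write many pieces by hand, such as combining the per-GPU losses and gradients, which makes it rather tedious.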
