Training on multiple GPUs with TensorFlow

Posted on 2018-12-02

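Introduction

Surprisingly, I could not find articles on this topic in Japanese, so I wrote one.
The source code is on GitHub, in the same repository as my previous article.
The implementation is based on TensorFlow's CIFAR-10 example.

Environment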

  • Python 3.6.4
  • tensorflow 1.9.0

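Training on multiple GPUs

To use multiple GPUs in TensorFlow, you have to define the model once per GPU (one "tower" per device).
Each tower computes the gradients of its own loss, and the per-GPU gradients are then averaged and applied to the shared variables in a single update.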

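    # args, custom_runner, optimizer, model_fn_multigpu and average_gradients
    # are defined elsewhere in the repository.
    gradients = []
    for gpu_index in range(args.num_gpus):
        # Place this tower's ops on the gpu_index-th GPU.
        with tf.device('/gpu:%d' % gpu_index):
            # The name scope ('gpu_0/', 'gpu_1/', ...) is used below to select
            # only this tower's entries from the graph collections.
            with tf.name_scope('%s_%d' % ("gpu", gpu_index)) as scope:
                # get_inputs() is called once per tower.
                images, labels = custom_runner.get_inputs()
                train_inputs = {'x': images, 'y': labels}

                # reuse=True asks the model to share one set of variables across towers.
                train_model_spec = model_fn_multigpu(train_inputs, reuse=True, is_train=True)

                tf.add_to_collection(tf.GraphKeys.LOSSES, train_model_spec['loss'])
                tf.add_to_collection(tf.GraphKeys.METRIC_VARIABLES, train_model_spec['accuracy'])
                losses = tf.get_collection(tf.GraphKeys.LOSSES, scope)
                accuracy = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES, scope)

                total_clone_loss = tf.add_n(losses)
                total_clone_accuracy = tf.add_n(accuracy)

                # compute this tower's (gradient, variable) pairs
                clone_gradients = optimizer.compute_gradients(total_clone_loss)
                gradients.append(clone_gradients)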

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grads = average_gradients(gradients)

    # Apply the gradients to adjust the shared variables.
    apply_gradient_op = optimizer.apply_gradients(grads)

    # Group all updates into a single train op.
    train_op = tf.group(apply_gradient_op)
    # Note: the loss and accuracy reported here are those of the last tower
    # built in the loop above.
    train_model_spec = {'train_op': train_op,
                        'loss': total_clone_loss,
                        'accuracy': total_clone_accuracy,
                        }
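
The average_gradients helper called above is not shown in this article. As a rough idea of what such a helper does, here is a minimal sketch, assuming every tower returns its (gradient, variable) pairs in the same variable order and that no gradient is None. It is an illustration of the averaging step, not the repository's implementation; plain floats stand in for gradient tensors so the example runs without TensorFlow.

```python
def average_gradients(tower_grads):
    """Average (gradient, variable) pairs across towers.

    tower_grads holds one list per GPU. Each inner list contains
    (gradient, variable) pairs in the same variable order, in the shape
    returned by optimizer.compute_gradients.
    Assumption: no gradient is None.
    """
    averaged = []
    # zip(*tower_grads) regroups the pairs so that each group holds the
    # gradients of one variable from every tower.
    for pairs_for_one_variable in zip(*tower_grads):
        grads = [grad for grad, _ in pairs_for_one_variable]
        mean_grad = sum(grads) / len(grads)
        # The variable is shared, so the first tower's reference is enough.
        variable = pairs_for_one_variable[0][1]
        averaged.append((mean_grad, variable))
    return averaged


# Two towers, two variables; strings stand in for the shared variables.
tower_grads = [
    [(1.0, "w"), (4.0, "b")],  # gradients computed on GPU 0
    [(3.0, "w"), (8.0, "b")],  # gradients computed on GPU 1
]
print(average_gradients(tower_grads))  # [(2.0, 'w'), (6.0, 'b')]
```

Because the averaged list has the same (gradient, variable) shape as a single tower's output, it can be passed to apply_gradients in the same way.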

Running

Pass the --num_gpus option to train_multi_gpu.py to specify the number of GPUs.
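
The training code reads this value as args.num_gpus. How the repository defines the option is not shown here; a minimal sketch with argparse, assuming an integer option with a default of 1 (the default and the build_parser name are assumptions for illustration), could look like this:

```python
import argparse


def build_parser():
    """Build a command-line parser with a --num_gpus option (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Multi-GPU training (sketch)")
    # Assumption: an integer option defaulting to a single GPU.
    parser.add_argument("--num_gpus", type=int, default=1,
                        help="number of GPUs to build towers on")
    return parser


# Equivalent to: python train_multi_gpu.py --num_gpus 2
args = build_parser().parse_args(["--num_gpus", "2"])
print(args.num_gpus)  # 2
```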

Example run

$ python train_multi_gpu.py --num_gpus 2
Epoch 1/10
100%|██████████| 937/937 [00:07<00:00, 118.05it/s, train_acc=0.871, train_loss=13.8]
train/acc: 0.8705, train/loss: 13.7522
valid/acc: 0.9402, valid/loss: 6.2843
Epoch 2/10
100%|██████████| 937/937 [00:06<00:00, 148.26it/s, train_acc=0.961, train_loss=4.02]
train/acc: 0.9610, train/loss: 4.0176
valid/acc: 0.9680, valid/loss: 3.3793
.
.
.

Running nvidia-smi during training confirms that both GPUs are in use.

$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla M60           On   | 00000000:00:1D.0 Off |                    0 |
| N/A   28C    P0    39W / 150W |   7365MiB /  7618MiB |     34%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla M60           On   | 00000000:00:1E.0 Off |                    0 |
| N/A   32C    P0    38W / 150W |   7365MiB /  7618MiB |     27%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     22375      C   python                                      7354MiB |
|    1     22375      C   python                                      7354MiB |
+-----------------------------------------------------------------------------+

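When train.py is run without any GPU placement, only one GPU does the work: GPU 1 stays at 0% utilization, although TensorFlow still allocates memory on it.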

$ python train.py
Epoch 1/10
100%|██████████| 1875/1875 [00:24<00:00, 76.53it/s, train_acc=0.673, train_loss=33.3]
train/acc: 0.6735, train/loss: 33.3285
valid/acc: 0.8651, valid/loss: 13.2471
Epoch 2/10
100%|██████████| 1875/1875 [00:23<00:00, 80.22it/s, train_acc=0.898, train_loss=10.1]
train/acc: 0.8980, train/loss: 10.1478
valid/acc: 0.9133, valid/loss: 8.7852
.
.
.
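
The progress bars above show 937 steps per epoch with two GPUs and 1875 with one. This is consistent with each tower consuming its own batch per step, so that one step handles num_gpus batches. A minimal sketch of that arithmetic follows; steps_per_epoch is a hypothetical helper, and the floor division is an assumption that matches the logs, not necessarily how the repository computes it.

```python
def steps_per_epoch(num_batches, num_gpus):
    """Number of training steps when each step consumes one batch per GPU.

    Assumption: leftover batches that do not fill every GPU are dropped.
    """
    return num_batches // num_gpus


print(steps_per_epoch(1875, 1))  # 1875
print(steps_per_epoch(1875, 2))  # 937
```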

$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 396.37                 Driver Version: 396.37                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla M60           On   | 00000000:00:1D.0 Off |                    0 |
| N/A   28C    P0    40W / 150W |   7365MiB /  7618MiB |     25%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla M60           On   | 00000000:00:1E.0 Off |                    0 |
| N/A   32C    P0    36W / 150W |   7245MiB /  7618MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     22698      C   python                                      7354MiB |
|    1     22698      C   python                                      7234MiB |
+-----------------------------------------------------------------------------+
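
In the output above, the single-GPU run still holds about 7 GB on GPU 1 even though its utilization is 0%. By default, TensorFlow 1.x reserves memory on every GPU visible to the process. If that is unwanted, one common approach is to limit the visible devices with the CUDA_VISIBLE_DEVICES environment variable before TensorFlow is imported. The sketch below is an illustration only; restrict_visible_gpus is a hypothetical helper, not part of the repository.

```python
import os


def restrict_visible_gpus(gpu_indices):
    """Expose only the given GPU indices to CUDA in this process.

    Must be called before TensorFlow is imported, because the variable is
    read when CUDA is initialized.
    """
    value = ",".join(str(index) for index in gpu_indices)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


# Make only GPU 0 visible, then import TensorFlow afterwards.
restrict_visible_gpus([0])
print(os.environ["CUDA_VISIBLE_DEVICES"])  # 0
```

The same effect is available from the shell by prefixing the command, for example CUDA_VISIBLE_DEVICES=0 python train.py.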

Conclusion

Keras makes multi-GPU training easy to implement, but in TensorFlow you have to define many things by hand, such as how the per-GPU losses are combined, which is tedious.
