42
36

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

pytorchでCNNのlossが毎回変わる問題の対処法 (on cpu)

Last updated at Posted at 2018-08-13

はじめに

CNNのプログラムを実行していると、乱数に依存している箇所が多く、実行するたびに出力結果が変わる。これもあってか論文の完全再現が自分でも難しいことがある(小数点第2位オーダだと)。その問題に真面目に戦った話。

on gpuはこちら

検証環境

macOS High Sierra 10.13.6
python 3.5.5
pytorch 0.4.0
torchvision 0.2.1

検証

このプログラムを3回実行して結果を比較(長いので省略)
https://github.com/chatflip/qiita_code/blob/master/deterministic/not_deterministic1.py

terminal
~/qiita_code/deterministic$ python not_deterministic1.py
not_deterministic1.py
							      	1回目           2回目            3回目
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.304123  Loss: 2.300169  Loss: 2.294914
Train Epoch: 1 [64/60000 (0%)]	Loss: 2.306989  Loss: 2.302103  Loss: 2.290142
Train Epoch: 1 [128/60000 (0%)]	Loss: 2.319033  Loss: 2.294176  Loss: 2.300677
Train Epoch: 1 [192/60000 (0%)]	Loss: 2.305209  Loss: 2.313074  Loss: 2.310522

lossが実行するたびに違う値になりました。このnot_deterministic1.pyは多くの乱数が使われています。基本的には、乱数のseedを固定すれば何度実行しても同じ値になります。乱数を用いているところは以下の通り(他にもあるかもしれません…)

  • 読み込んだ画像の変形(ex. transforms.RandomCrop)
  • データセットのシャッフル(ex. DataLoaderのshuffle=True)
  • ネットワークの初期重み(ex. nn.Conv2d.weight)

これらの乱数のseedを固定するためにpytorch/torchvisionのライブラリを漁ります。

読み込んだ画像の変形

torchvision.transforms内で用いられている乱数はpythonの標準ライブラリのrandomです。
なのでrandom.seed()を用いてseedを固定します。

transforms.RandomCropの乱数部分
https://github.com/pytorch/vision/blob/v0.2.1/torchvision/transforms/transforms.py#L399
transforms.RandomHorizontalFlipの乱数部分
https://github.com/pytorch/vision/blob/v0.2.1/torchvision/transforms/transforms.py#L447
pythonの標準ライブラリrandom
https://docs.python.jp/3.5/library/random.html

データセットのシャッフル

DataLoaderの引数のshuffleをTrueにすると関数の中ではRandomSamplerが呼ばれます。
RandomSamplerの中ではtorch.randpermによってデータセットがシャッフルされます。
torch.randpermの乱数のseedはtorch.manual_seed()で設定可能です。

Random sampling creation ops are listed under [Random sampling] and ...
(https://pytorch.org/docs/0.4.0/torch.html#random-sampling)
https://pytorch.org/docs/0.4.0/torch.html?highlight=randperm#creation-ops

Dataloaderのshuffle=True
https://github.com/pytorch/pytorch/blob/v0.4.0/torch/utils/data/dataloader.py#L434
RandomSamplerの乱数部分
https://github.com/pytorch/pytorch/blob/v0.4.0/torch/utils/data/sampler.py#L51
torch.manual_seed
https://pytorch.org/docs/0.4.0/torch.html?highlight=randperm#torch.manual_seed

ネットワークの初期値

nn.Conv2dが宣言されるとself.weightにnn.Parameter()が格納され、torch.Tensorのuniform_(一様乱数)で初期化されます。
これもtorch.manual_seed()で設定可能です。

nn.ConvNdの重み設定
https://github.com/pytorch/pytorch/blob/v0.4.0/torch/nn/modules/conv.py#L31
https://github.com/pytorch/pytorch/blob/v0.4.0/torch/nn/modules/conv.py#L45

再検証

以上の内容を踏まえたプログラム(not_deterministic2.py)を実行

terminal
~/qiita_code/deterministic$ python not_deterministic2.py --num_workers=0
not_deterministic2.py
							      	1回目           2回目            3回目
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.283671	Loss: 2.283671	Loss: 2.283671
Train Epoch: 1 [64/60000 (0%)]	Loss: 2.313809	Loss: 2.313809	Loss: 2.313809
Train Epoch: 1 [128/60000 (0%)]	Loss: 2.318310	Loss: 2.318310	Loss: 2.318310
Train Epoch: 1 [192/60000 (0%)]	Loss: 2.315323	Loss: 2.315323	Loss: 2.315323

結果が同じになった。しかし、num_workerを>0にすると

terminal
~/qiita_code/deterministic$ python not_deterministic2.py --num_workers=4
not_deterministic2.py
							      	1回目           2回目            3回目
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.308435	Loss: 2.309057	Loss: 2.313945
Train Epoch: 1 [64/60000 (0%)]	Loss: 2.309005	Loss: 2.315649	Loss: 2.306219
Train Epoch: 1 [128/60000 (0%)]	Loss: 2.307885	Loss: 2.304023	Loss: 2.307097
Train Epoch: 1 [192/60000 (0%)]	Loss: 2.314316	Loss: 2.316156	Loss: 2.308462

結果が変わる。この原因はDataLoaderの子プロセスの乱数のseedが固定されていないことが原因です。公式に書かれています。

By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by main process using its RNG. However, seeds for other libraies may be duplicated upon initializing workers (w.g., NumPy), causing each worker to return identical random numbers. (See My data loader workers return identical random numbers section in FAQ.) You may use torch.initial_seed() to access the PyTorch seed for each worker in worker_init_fn, and use it to set other seeds before data loading.

torch.utils.data.DataLoader
https://pytorch.org/docs/0.4.0/data.html#torch.utils.data.DataLoader

なのでworker_init_fnで乱数のseedを指定してあげましょう。

deterministic.py(追加部分)
def worker_init_fn(worker_id):
    random.seed(worker_id)

seedの指定にworker_idを使っているのは、子プロセスのseedが全部同じになることを防ぐためです。

再再検証

今まで出た変更点を考慮したプログラムを実行

terminal
~/qiita_code/deterministic$ python deterministic_cpu.py --num_workers=4
deterministic_cpu.py
							      	1回目           2回目            3回目
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.316236	Loss: 2.316236	Loss: 2.316236
Train Epoch: 1 [64/60000 (0%)]	Loss: 2.310289	Loss: 2.310289	Loss: 2.310289
Train Epoch: 1 [128/60000 (0%)]	Loss: 2.303744	Loss: 2.303744	Loss: 2.303744
Train Epoch: 1 [192/60000 (0%)]	Loss: 2.311146	Loss: 2.311146	Loss: 2.311146

おわりに

pytorchでlossが毎回変わる問題の対処法は

  • random.seed()torch.manual_seed()を追加
  • torch.utils.data.DataLoadernum_workers>0ならworker_init_fnで子スレッドでrandom.seed()を呼んでseedを固定

追記

以下のことを変えて実行すると結果は異なるので注意。

  • batch_sizeの数の変更
  • num_workerの数の変更
  • pythonの系統の違い(python2.xとpython3.x)
42
36
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
42
36

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?