はじめに
CNNのプログラムを実行していると、乱数に依存している箇所が多く、実行するたびに出力結果が変わる。これもあってか論文の完全再現が自分でも難しいことがある(小数点第2位オーダだと)。その問題に真面目に戦った話。
on gpuはこちら
検証環境
macOS High Sierra 10.13.6
python 3.5.5
pytorch 0.4.0
torchvision 0.2.1
検証
このプログラムを3回実行して結果を比較(長いので省略)
https://github.com/chatflip/qiita_code/blob/master/deterministic/not_deterministic1.py
~/qiita_code/deterministic$ python not_deterministic1.py
1回目 2回目 3回目
Train Epoch: 1 [0/60000 (0%)] Loss: 2.304123 Loss: 2.300169 Loss: 2.294914
Train Epoch: 1 [64/60000 (0%)] Loss: 2.306989 Loss: 2.302103 Loss: 2.290142
Train Epoch: 1 [128/60000 (0%)] Loss: 2.319033 Loss: 2.294176 Loss: 2.300677
Train Epoch: 1 [192/60000 (0%)] Loss: 2.305209 Loss: 2.313074 Loss: 2.310522
lossが実行するたびに違う値になりました。このnot_deterministic1.py
は多くの乱数が使われています。基本的には、乱数のseedを固定すれば何度実行しても同じ値になります。乱数を用いているところは以下の通り(他にもあるかもしれません…)
- 読み込んだ画像の変形(ex. transforms.RandomCrop)
- データセットのシャッフル(ex. DataLoaderの
shuffle=True
) - ネットワークの初期重み(ex. nn.Conv2d.weight)
これらの乱数のseedを固定するためにpytorch/torchvisionのライブラリを漁ります。
読み込んだ画像の変形
torchvision.transforms内で用いられている乱数はpythonの標準ライブラリのrandomです。
なのでrandom.seed()
を用いてseedを固定します。
transforms.RandomCropの乱数部分
https://github.com/pytorch/vision/blob/v0.2.1/torchvision/transforms/transforms.py#L399
transforms.RandomHorizontalFlipの乱数部分
https://github.com/pytorch/vision/blob/v0.2.1/torchvision/transforms/transforms.py#L447
pythonの標準ライブラリrandom
https://docs.python.jp/3.5/library/random.html
データセットのシャッフル
DataLoaderの引数のshuffleをTrueにすると関数の中ではRandomSamplerが呼ばれます。
RandomSamplerの中ではtorch.randperm
によってデータセットがシャッフルされます。
torch.randpermの乱数のseedはtorch.manual_seed()で設定可能です。
Random sampling creation ops are listed under [Random sampling] and ...
(https://pytorch.org/docs/0.4.0/torch.html#random-sampling)
https://pytorch.org/docs/0.4.0/torch.html?highlight=randperm#creation-ops
Dataloaderのshuffle=True
https://github.com/pytorch/pytorch/blob/v0.4.0/torch/utils/data/dataloader.py#L434
RandomSamplerの乱数部分
https://github.com/pytorch/pytorch/blob/v0.4.0/torch/utils/data/sampler.py#L51
torch.manual_seed
https://pytorch.org/docs/0.4.0/torch.html?highlight=randperm#torch.manual_seed
ネットワークの初期値
nn.Conv2dが宣言されるとself.weightにnn.Parameter()が格納され、torch.Tensorのuniform_(一様乱数)で初期化されます。
これもtorch.manual_seed()で設定可能です。
nn.ConvNdの重み設定
https://github.com/pytorch/pytorch/blob/v0.4.0/torch/nn/modules/conv.py#L31
https://github.com/pytorch/pytorch/blob/v0.4.0/torch/nn/modules/conv.py#L45
再検証
以上の内容を踏まえたプログラム(not_deterministic2.py)を実行
~/qiita_code/deterministic$ python not_deterministic2.py --num_workers=0
1回目 2回目 3回目
Train Epoch: 1 [0/60000 (0%)] Loss: 2.283671 Loss: 2.283671 Loss: 2.283671
Train Epoch: 1 [64/60000 (0%)] Loss: 2.313809 Loss: 2.313809 Loss: 2.313809
Train Epoch: 1 [128/60000 (0%)] Loss: 2.318310 Loss: 2.318310 Loss: 2.318310
Train Epoch: 1 [192/60000 (0%)] Loss: 2.315323 Loss: 2.315323 Loss: 2.315323
結果が同じになった。しかし、num_workerを>0にすると
~/qiita_code/deterministic$ python not_deterministic2.py --num_workers=4
1回目 2回目 3回目
Train Epoch: 1 [0/60000 (0%)] Loss: 2.308435 Loss: 2.309057 Loss: 2.313945
Train Epoch: 1 [64/60000 (0%)] Loss: 2.309005 Loss: 2.315649 Loss: 2.306219
Train Epoch: 1 [128/60000 (0%)] Loss: 2.307885 Loss: 2.304023 Loss: 2.307097
Train Epoch: 1 [192/60000 (0%)] Loss: 2.314316 Loss: 2.316156 Loss: 2.308462
結果が変わる。この原因はDataLoaderの子プロセスの乱数のseedが固定されていないことが原因です。公式に書かれています。
By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by main process using its RNG. However, seeds for other libraies may be duplicated upon initializing workers (w.g., NumPy), causing each worker to return identical random numbers. (See My data loader workers return identical random numbers section in FAQ.) You may use torch.initial_seed() to access the PyTorch seed for each worker in worker_init_fn, and use it to set other seeds before data loading.
torch.utils.data.DataLoader
https://pytorch.org/docs/0.4.0/data.html#torch.utils.data.DataLoader
なのでworker_init_fn
で乱数のseedを指定してあげましょう。
def worker_init_fn(worker_id):
random.seed(worker_id)
seedの指定にworker_idを使っているのは、子プロセスのseedが全部同じになることを防ぐためです。
再再検証
今まで出た変更点を考慮したプログラムを実行
~/qiita_code/deterministic$ python deterministic_cpu.py --num_workers=4
1回目 2回目 3回目
Train Epoch: 1 [0/60000 (0%)] Loss: 2.316236 Loss: 2.316236 Loss: 2.316236
Train Epoch: 1 [64/60000 (0%)] Loss: 2.310289 Loss: 2.310289 Loss: 2.310289
Train Epoch: 1 [128/60000 (0%)] Loss: 2.303744 Loss: 2.303744 Loss: 2.303744
Train Epoch: 1 [192/60000 (0%)] Loss: 2.311146 Loss: 2.311146 Loss: 2.311146
おわりに
pytorchでlossが毎回変わる問題の対処法は
-
random.seed()
とtorch.manual_seed()
を追加 -
torch.utils.data.DataLoader
のnum_workers>0
ならworker_init_fn
で子スレッドでrandom.seed()
を呼んでseedを固定
追記
以下のことを変えて実行すると結果は異なるので注意。
- batch_sizeの数の変更
- num_workerの数の変更
- pythonの系統の違い(python2.xとpython3.x)