3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

乱数固定の落とし穴(Python, PyTorch)

Last updated at Posted at 2023-12-03

直面した問題

  1. PyTorchで実装したある学習モデルを、他の学習スクリプトでも使おうと流用。
  2. モデルが正常に機能しているか確かめるためにSEEDを固定し学習してみると、モデルの初期値が移植前のものと異なることが判明。
  3. モデルのパラメータ等が同じになるよう見直すなどしてみるも、同じ初期値で初期化されない。

原因

  • SEEDを固定しても乱数を使うタイミング(固定してから乱数を使う回数)が異なれば異なる乱数が得られます。
  • 固定したい乱数(今回、モデルの初期値)の生成直前でSEEDを固定する必要があります。

解説

前提知識

Pythonの乱数固定はモジュールごとに行う必要があります。
PyTorchの環境では一般的に以下のようにPythonの組み込みrandomモジュール、numpy、PyTorchのSEEDを固定すれば良いと思います。

import random
import numpy as np
import torch

def fix_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

下記記事参考にさせて頂きました。

補足 2023/12/05
numpyの乱数固定については通常globalなnp.random.seed()よりもnp.random.RandomState()を使う方が良いというご指摘を頂きましたので補足致します。

乱数の挙動について

乱数の挙動の詳細を見ていきます。
ramdomモジュールを例にしていますが、numpy、PyTorch共に考え方は同じと思われます。

import random

a = [random.randint(0, 9) for i in range(5)]
print(a) # -> [9, 4, 9, 0, 5] (実行例、実行の度に結果は変わる)

このコードでは、0~9の範囲の長さ5の乱数列が生成されます。
乱数なので当然、実行の度に結果は変わります。

では次にSEEDを固定してみます。

import random

random.seed(0)
a = [random.randint(0, 9) for i in range(5)]
print(a) # -> [6, 6, 0, 4, 8] (実行例、何度実行しても結果は同じ)

SEEDを固定すれば、何度実行しても同じ乱数列が得られます。
ここまでは問題ありません。

では次に乱数列を2つ作ってみます。
例では同じことを2回繰り返しているだけです。

import random

random.seed(0)

a = [random.randint(0, 9) for i in range(5)]
print(a) # -> [6, 6, 0, 4, 8] (何度実行しても結果は同じ)

b = [random.randint(0, 9) for i in range(5)]
print (b) # -> [7, 6, 4, 7, 5] (何度実行しても結果は同じ)

乱数列aとbは何度実行しても同じなのですが、aとbは違う乱数列になっています。
aとbを同じ乱数列にするためには、bの乱数を得る前にもう一度SEEDを固定します。

import random

random.seed(0)

a = [random.randint(0, 9) for i in range(10)]
print(a) # -> [6, 6, 0, 4, 8]

random.seed(0) # これを追加

b = [random.randint(0, 9) for i in range(5)]
print (b) # -> [6, 6, 0, 4, 8]

なお以下のように乱数列を10個生成するようにしてみると、結果はaとbを連結したものになっていることが分かります。

import random

random.seed(0)

a = [random.randint(0, 9) for i in range(10)]
print(a) # -> [6, 6, 0, 4, 8, 7, 6, 4, 7, 5] (aとbを連結した形)

乱数は呼ばれる度にランダム(に見える)値を返す仕組みですが、SEEDの固定とはその数列をリセットするための操作となるため、同じ数列を得たい場合にはリセットしてから何回乱数が生成されているかの回数を同じにしなければなりません。
しかし実際には回数を同じにするより、必要な場面で毎回SEEDを固定し直すのが簡単です。

今回のトラブルの原因

今回は以下のようなコードを使っていました。

import random
import numpy as np
import torch

fix_seeds(0) # とりあえずプログラムの先頭あたりでfixするのが一般的なのでそうしていた

# ~いろいろな初期化処理等~

model = MyModel()

# ~学習処理~

def fix_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

~いろいろな初期化処理等~
が移植前と移植先のコードで異なっており、その中で恐らく乱数が使用されていたため、モデルの初期化に使われる乱数列が変化したためこのようなトラブルが起こっていました。
プログラムの先頭だけではなく、以下のように固定したい対象(今回の場合モデルの初期化処理)の直前でSEEDを再固定することで解決しました。

fix_seeds(0) # モデルの初期化直前で再固定
model = MyModel()

まとめ

  • 同じモデルや自作のクラス/関数を流用しても、利用する側のコードが異なれば乱数が変わる可能性があります。
  • 乱数のSEED固定を行うと、生成される乱数列が初期化されます。
  • 同じ数列を確実に得たい場合、その度にSEEDを固定するのが安全です。

とは言え、コードが大きくなると見逃す箇所も出てくると思います。
PyTorchのようにフレームワークを使う場合、関数の中で知らぬ間に乱数が使われているなどということもありそうです。
乱数の固定はいわゆる再現性の確認に使われると思いますが、そこばかりに注力しても無駄な労力を生じる可能性があるので、ほどほどにしておいた方が良いのかもしれません。

3
1
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?