3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

深層学習、機械学習のシード固定のはまりポイント紹介

Last updated at Posted at 2022-03-08

初めに

機械学習や深層学習をしていると学習を再現させるためにシードを固定する必要が出てくると思います。
私もその必要があり、機械学習におけるランダムシードの研究を参考にしてシードを固定させていたのですが、私の環境だとどうにもうまく再現できませんでした。
同じように困っている方がいるのではないかと思い記事にしました。

原因

アルゴリズム内での計算による誤差が原因で特定の数値に影響を及ぼしていました。
私が使用したツール(アルゴリズム)はscikit-learnのKMeansだったのですが、predictによって得られるクラスターの予測値については全く影響ありません。
影響があったのはKMeansでパラメータ計算(fit)後のcluster_centers_の値です。
※誤差の値は非常に小さい値だったのでクラスター番号の予測には影響しなかったと思われます。

発見までの流れ

なぜこの問題に気付いたかの流れです。

  1. シードを固定して深層学習を実施
  2. シードを固定したにも関わらず学習を再現することができない
  3. シードの固定漏れを調査するも見つからない
  4. 予測モデルのモード(evalモード)変更忘れを疑うも問題なし
  5. データの加工の流れから再現しない箇所を特定
  6. 再現しない原因がKMeansのcluster_centers_の誤差ということが判明

という流れでした。
scikit-learnではシードをrandom_stateという引数で設定できるものがあり、KMeansもその引数でシードを固定していたので特定までに時間がかかってしまいました。

誤差検証コード

適当なデータをscikit-learnのmake_blobsで作成し誤差の検証をしてみたいと思います。

sample.py
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
#適当なサンプルデータを作成します
X,Y = make_blobs(random_state=0, n_samples=1000, n_features=2, cluster_std=1.5, centers=4)
#K-meansのインスタンス1つ目
clf1 = KMeans(random_state=0,n_clusters=4)
clf1.fit(X)
#K-meansのインスタンス2つ目
clf2 = KMeans(random_state=0,n_clusters=4)
clf2.fit(X)
#cluster_centersの比較
print((clf1.cluster_centers_ == clf2.cluster_centers_).all())

おそらくFalseが出力されると思います。
Trueとなる場合もありますが複数回実行するとFalseとなります。
ただし、この誤差の影響は極めて限定されています。

誤差の影響について

誤差の影響を調べてみました。
上記の検証でcluster_centers_の値が一致しなかったパラメータをプロットした結果がこちらです。
box-image.png

元のデータのサイズがある程度の大きさのためこの図を見る限りでは全く同じに見えます。
これだけでほぼ影響がないと言えますが一応2つのcluster_centers_の差分を計算してみます。

sample.py
#cluster_centers_の絶対値の差分計算
np.absolute(clf1.cluster_centers_) - np.absolute(clf2.cluster_centers_)


#出力
array([[ 0.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00, -4.4408921e-16],
       [ 0.0000000e+00, -8.8817842e-16],
       [ 0.0000000e+00,  0.0000000e+00]])

このように非常に値が小さいことがわかります。

またこの誤差はマシンイプシロンに関係していることがわかりました。
下記のコードでマシンイプシロンを確認してみます。

sample.py
#マシンイプシロン
np.finfo(float).eps

#出力
2.220446049250313e-16

上記を見ると差分がちょうどマシンイプシロンの倍数になっています。
※差分を表示する方法によっては桁が大きすぎて省略されることがあります。
おそらく非常に小さい値の計算時に誤差が発生し上記のような結果が再現しきれない問題が発生したと思われます。
こちらは四捨五入や切り捨てなどの処理を加えることで回避可能と思われます。

参考

機械学習におけるランダムシードの研究

まとめ

  • 誤差怖い
  • けど厳密な計算をしないなら誤差が発生しないようにけたを落とせば解決!
3
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?