3
3

More than 3 years have passed since last update.

Chainer と TensorFlow と PyTorch の Dropout 実装メモ

Last updated at Posted at 2019-05-03

背景

  • TensorFlow Lite など, 推論エンジンでは乱数が使えない場合が多い(もしくは IoT 系デバイスだと, メルセンヌツイスターなどの擬似乱数を処理させると重い)
  • 既存の機械学習ライブラリで Dropout をどのように実装しているか調べ, 推論でどう扱うかの判断をしたい

Chainer(v5.4)

Chainer では CPU では numpy.random を使い, GPU では CUDA の random を使っています.

function/link レベルでの seed の設定はありません.

deterministic に評価する場合は forward(backward) 実行前に numpy で seed を設定すればいいのかしら?

Tensorflow r1.13

Tensorflow 自体には Dropout という builtin op はありません.

RandomUniform, Scale, Division(+Floor) に分解されます.

TensorFlow では, graph レベルと op レベルで seed を設定することができます.

TensorFlowで乱数シードを固定する
https://qiita.com/yuyakato/items/9a5d80e6c7c41e9a9d22

これにより, deterministic な実行にさせることができる... はず.

TensorFlow Lite の場合, RandomUniform builtin op はありません.

toco の graph transformation で RandomUniform は定数のテンソルに変換されます(seed が設定されている場合. seed が設定されていないと純粋にランダムになり定数化できないので変換エラーになります)
このとき, 乱数には tf 側の乱数ルーチンを使って, 振る舞いを同一にしています.

自作の tflite コンバーターを作る場合は, 乱数の生成をどうするか考える必要がありますね.

PyTorch 1.3.x

PyTorch の nn.Droput は, いろいろ関数呼びまくりでソースコードを追うのが面倒なのですが, どうも最終的には ATen/native/Dropout.cpp の実装を呼んでいるようです

乱数はベルヌーイ分布を使っていますね.

推論側での対応について

基本は deterministic にして事前計算ですが, 実行時に乱数生成することも考えてみます.

WebGL などの GPU 実行で, 本質的に乱数を生成するのが難しい環境では, ノイズテクスチャを作ったり, 近似の乱数生成関数や, シェーダで評価しやすい乱数生成関数を使うことができそうです.

CPU で組み込み系であれば, xorshift や pcg32 などがよいでしょうか.

推論の呼び出し回数などが, ランタイム側で取得できれば, 呼び出し回数を seed にして事前計算した乱数のテンソルを切り替えるという手もありそうです.

モンテカルロレイトレーシング業界からみた Dropout のふるまい

機械学習でいう Dropout は, モンテカルロレイトレーシング業界からみると russian roulette に似ていますね.

TODO

  • pytorch での実装を調べる(pytorch では dropout での乱数は, 一様分布ではなくベルヌーイ分布を利用している模様)
  • xorshift とかの軽量乱数生成で学習させるとどうなるか調べる
  • Halton 列などの準モンテカルロで学習させるとどうなるか調べる
  • Deterministic(seed 固定)で学習させると, 学習の精度がどうなるか調べる
  • 学習時は non-deterministic で, 推論時だけ deterministic にすると, 推論の精度がどうなるか調べる
3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3