- 論文
- Tweet
- pytorch実装: https://github.com/bharathgs/NALU
概要
- 加減乗除+べき乗が扱えるようなユニットNALU(cf. LSTM)を作った
- 時間計測、文字上の数字を実際の数字にする、画像上で物体数を数える等のタスクで既存手法よりよい結果を得た
- 今までと違い、訓練データの範囲を超えたような数値入力が来てもうまく答えられるものになった
1.1. シンプルな実験による現状の課題意識
- 数字を入れて同じ数字を返すAutoEncoderを考える
- -5から5までTrainして、-20から20まででテストした結果
- どの関数も未知のところでは、完全には学習できていない
2. どう解決するか?
数をシステマティックに扱う機構を作ることを目指す。
数学的にいうと、liner exextrapolation/multiplicative
extrapolationできる理想的なinductive bias(帰納バイアス)を設定すればいいらしい。
まずは、NACというシンプルに加減算ができるUnitをつくり、加減乗除+べき乗演算ができるNALUというUnitを作る
neural accumulator (NAC)
- W は {-1, 0, 1} のどれかしか使わない
- これにより足し算と引き算しかできなくなり変にスケーリングされることがなくなることが保証できる
- これだけだと非連続で微分可能にできないので以下のようにWを定式化する
- W = tanh(Wˆ)*sigmoid(Mˆ)
- ここで*は要素ごとの積
- W = tanh(Wˆ)*sigmoid(Mˆ)
Neural Arithmetic Logic Units (NALU)
- NALUは足し算/引き算を司るNAC(上側の小さい紫)と掛け算/割り算を司るNAC(下側の大きい紫)と比率をコントロールするゲート(オレンジ)のユニットからなる
まとめ
4. 実験
4.1. 単純な関数を再現する実験
- Static Taskはあらかじめ変数が全部見えているケース、Recurrent Taskは連続的に値が入ってくるケース
- 実際はStaticは100次元のVectorが入力、Reccurentは100次元のVectorが入力で10単位時間分の特定部分の合計がaやbになっている
- ReccurentのタスクはExtrapolationは単位時間を10から1000に伸ばすことで実現
- extrapolationの方は訓練データの範囲を超えた数値データが入った場合の結果
- NALUは、割り算を除いてExtrapolationのデータでもパーフェクトの結果を出している
- Reccurentの割り算が特に悪いのは、このモデルは0にかなり近い数値も扱える構造なので、それが分母に入ると値が飛んでしまうのが理由らしい
4.2. MNIST Counting / Addition Task
- 実験の目的: 数値でないInputからのBack Propagationが可能なことを示す
- 訓練は10個のデータで行い、テストは1,10,100,1000で行う
- Counting Taskは各数字が何個ずつあるか数えるタスク (COUNT)
- Addition Taskは各数字を全部足し込むタスク (SUM)
- 足し算と引き算のみ行えるUNITのNACの結果が群を抜いている。NALUもすごく悪いわけではない
- RNN-tanh, RNN-ReLUは訓練データよりも長い系列だけでなく、短い系列の時もうまくいかないことがある
4.3. 言語を数値にする問題
問題の概要は下を見るのが一発。すごい。andで1ぐらい増えているのが面白い。andがあった瞬間に最低1は増えるからだろうか?
実験の結果。LSTMの後linear layerを入れるかNACを入れるかNALUを入れるかの比較
- NACがかなり悪いので、このタスクはかけ算の存在が重要な役割を占めていることがわかる
4.4. Programの実行実験
- プログラムといってもif文、足し算、引き算がある構造の実行
- 0から99までの値で実験して、3桁、4桁の結果でどうなるか実験
4.5. Grid-World環境での時間計測実験
- 実験の目的:強化学習の内部構造として使うことで時間概念を理解しながら学習を行えるのではないか
- Dating Taskと呼ばれるこのタスクは、指定された時間TぴったりにAgent(灰色)が目的地(赤)にたどり着くことを目的としている
- 選べる行動は{UP, DOWN, LEFT, RIGHT, PASS}の5つ
- 報酬は時間ぴったりについたときにMaxで、早く着きすぎると早かった時間分だけマイナス報酬
- 動くことにコストがかかるみたいなことはない模様
- 時間内にたどり着けないケースは取り扱わないことにする
- 実験結果は、5-12ステップでたどり着けるデータで訓練したときのテスト結果
- 両方うまくいっているわけではない
- 13以上になってもT=12の時間にたどり着こうとしてしまってスコアが下がっている
- NACを使った方がそうなるケースが少ない
- 早くたどりつき過ぎてしまって、T=12に着いた方が報酬がマイナスになるポイントからはNACの方はたどり着こうとしなくなる
4.6. MNIST Parity Task
- MNISTの数字2つがあってそれのその2つの数値の偶奇が一致しているか調べるタスク