Relational Networkの続きです。RRN(Reccurent Relational Network)です。めっちゃ時間かかりました。とりあえず論文のリンクをhttps://arxiv.org/abs/1711.08028。Tensorflow2での実装がなかったので自分で実装したら、迷路に迷い込みました。コードはここ→https://github.com/rasmusbergpalm/recurrent-relational-networksを参考にしました。
RRNの形式
RNでは
$RN(O)=f_ϕ(\sum_{i,j}g_θ(o_i,o_j))$
としていたものをRRNでは
$m_{ij}^t = f_{\phi}(h_i^t,h_j^t)$
$m_j^t=\sum_im_{ij}^t$
$h_j^t=g_{\theta}(h_j^t,h_j^0,m_j^t)$
のようにします。ここで$i,j$は各ノードの添え字、$t$は時刻の添え字、$h_i^t,h_j^t$はノードi,jの状態ベクトル、$m_{ij}^t$はノードjからノードiへのメッセージベクトル、
この式を見てやっとわかりました。RNもRRNもGraph neural networkの一種だったんですね。$\sum$の理由が分からなかったんですが、エッジのあるノード間が窓になっているようなCNNと思えば、なんとなく納得できました。上記の式以外の形式のGraph neural networkもいくつもあるみたいです。→https://qiita.com/shionhonda/items/d27b8f13f7e9232a4ae5
RNNの実装
まあ、ひどい実装です。直したいところはいくらでもありますが、とりあえずは動いたので載せるといった感じです。
実装したリポジトリ→
参考にしたリポジトリ→https://github.com/rasmusbergpalm/recurrent-relational-networks
RNNの結果
テストデータの正解率だけで言えば、残念ながら20%を少し下回る結果です。ただし、学習している傾向は見て取れるます。下のgifの画像は訓練後のネットワークの各時刻における出力です。数独のルールに基づいて回答を様子が分かります。
正解率が低い理由としては最後の詰めが甘いようです。特に最初に空きマスが多いとなかなか各ノードの状態ベクトルが収束せず、ステップ数不足で間違える用です。
問題点
1.損失をどのように与えたら良いのかが難しい。
数独ではそのタイミングでは明らかに埋められないマスと埋められるマスがあります。あそこのマスが埋まったからこちらのマスも埋められるという思考のステップが必要になります。しかし、今回のプログラムはすべてのステップで全てのマスに対して損失を与えています。絶対に埋められないマスがあるにも関わらず損失を与えるのは理不尽に思えます。もちろんそれでも学習はするのでしょうが、むしろノイズとなってしまって学習の妨げになっている可能性もあると思います。解決策としては強化学習で損失にマスクをかけるとかでしょうか。
2.ルールをハードコーディングしている部分がある
作成したモデルではノードの接続は数独のルール(縦、横、同じ3x3のマスとは違う数字となる)に従った接続になっています。これはドメイン知識を活用しているという見方もできますが、汎化性能を下げる要因になります。実際に解きたいARCでは問題によって、ノードの接続は変化するべきです。一番簡単な実装は全接続ですが、これもノイズを増やす原因になりえると思っています。Attenntionの枠組みを使って接続するノードを制限するといったことも考えています。
3.学習の不安定さ
モデルが完成してから、起きた問題として学習の不安定さがあります。optimizerの学習率を上げすぎるとlossの減少速度は上がるが、学習が進んだあるタイミングでlossがまた上がり始めるという現象が起きるということが何度も試しているうちにわかりました。原因はよく分かりませんが、重みの値が発散している可能性を考えています。そのために重みに制限を持たせて、少しはましになった様子もありますが、それでも学習率を下げないとその現象が起きます。この問題に長い時間、悩まされました。リカレントな構造を持っていることも原因かもしれません。ARCを解くための問題にならなければいいのですが。
4.実行速度
学習にかなり時間がかかります。今回の正解率を達成するためだけでも2日ほど学習時間がかかったはずです。(正確に測定はしていません)多層のMLPを使うのでそれなりの計算量がかかります。計算量の削減も課題でしょう。時間は惜しいです。ハード面で解決するという手もありますが。
終わりに
本当にこんな時間がかかるとは。元論文のコピーをするだけと思ってたのに。まあ、Tensorflow2.x系の勉強にはなりました。それにネットワークが数独を解く動画は本当に思考の様子を見ているようで面白いです。またこれもひどい記事ですね。とりあえず、論理みたいなものを扱う、neural networkはできました。(改善点は驚くほどありますが)次はfewshot leaningやります。それ以外に強化学習とかattentionも必要なきがしてきたし。大変だなあ。