LoginSignup
7

More than 5 years have passed since last update.

Dueling Networkを実装する(1)

Last updated at Posted at 2016-10-15

環境

GPU GTX1070
ubuntu 14.04
chainer 1.14.0
など

はじめに

最近numpyを復習しているが、その過程で昔実装したDueling Networkが間違っていたのではないかという疑念が湧いてきた。そこで改めてDeuling Network固有の層を確認し、問題があればコードを修正する。

Z. WangらのDueling Networkの論文は以下。
https://arxiv.org/pdf/1511.06581v3.pdf
この解説を簡潔にまとめています。
http://www.slideshare.net/ssuser07aa33/introduction-to-dueling-network

DQNからDueling Networkへ

今回実装するDueling NetworkはMnihらのDQN論文(2015, nature)に対して出力層を変更しただけのものとする。つまりDouble DQN、あるいはPriorized Replayなど論文中の他のテクニックは実装しない。

実装部分の概要

以下の図はMnih(2015)らのDQNとDueling Networkを比較したものである。
qiita_DN01.png
DQNはconv層3つのあと、全結合層を経て出力のQ値へとつながっている。一方でDueling Networkではconv3層のあと2つの全結合に分かれ、一方がVの算出、もう一方がAの算出となっている。最後にこれらからQ値を求めている。この詳細が以下の図である。
qiita_DN02.png
VおよびAを求めるまではlinearにReLuなので問題ない。最後のQ値の層だけ既存のconnectionを修正して新たに作る必要がある。前回と同様にforward関数、backward関数などを修正することで図中の数式を実現する。

bilinear.pyを確認する。

今回は2つの流れを結合させるので、bilinear.pyを修正することになる。そこでchainer/functions/connection/bilinear.pyの適当なところに

print 'W.shape',
print W.shape

などとプリントアウトしたところ、以下の図ような仕組みになっているみたい。bilinear_for.png
bilinear_back.png
今回はforwardもbackwardも演算がbilinearのそれと全く違うので、重要なのは渡されるデータのサイズくらいだろう。特にforward()で重みWを使用しない。しかし戻り値のvariableオブジェクトにWとかgWとかがないとエラーが出るかも知れないので、適当な値を入れておく。

以下、その(2)に続く。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7