はじめに
以前の足し算ゲームを強化学習で学習できるか?では、単純な足し算ゲームの良い行動を Q-Learningという方法で学習してみました。
今回は同じゲームを Chainerを使って学習させることにします。
ただ、完全に手探りで作ったので、正しくない箇所もまだ残っていそうですが、一応学習できたので投稿しておきます。
お題: 足し算ゲーム by Chainer
前回と同じで、以下のゲームを考えます。
- 状態S: 0~9 の整数
- アクションA: 1~4の整数
- 次状態S': (S + A) % 10
- 報酬R:
- +1: S' == 7
- -100: S' in (5, 9)
実験
実験に使ったソースコードはこちらです。
基本的には、前回のQLearningPlayer
を NNQLearningPlayer
に置き換える形になります。
まだ間違っている箇所もありそうですし、色々わからなかった、ハマった点もあったので少しコードの説明を書いておきます。
ちょうど この頃のコードが混乱の極みだったと思うので、そこからの修正ポイントを説明します。
Q-LearningをChainerで行う部分
ある意味一番重要な部分です。理屈としては下記の部分を参考に作ったつもりです。
DeepじゃないQ-network:Q学習 + 関数近似
ただ、ここの部分はそこまで酷く間違ってはいなかったのかなと思います。
def update_q_table(self, last_state, last_action, cur_state, last_reward):
target_val = last_reward + self.GAMMA * np.max(self.forward(cur_state, volatile=True).data) # target = r + γmaxQ(s', a') の部分
self.optimizer.zero_grads()
q_last = self.forward(last_state) # Q(s, a) の部分
tt = np.copy(q_last.data)
tt[0][last_action] = target_val # a の部分だけ学習するのかな、と思ったのでこうしている。
target = Variable(tt)
loss = 0.5 * (target - q_last) ** 2 # L = 1/2 * (target - Q(s,a))^2 の部分
loss.grad = np.array([[self.ALPHA]], dtype=np.float32) # ここから3行が θ ← θ - αΔL の部分
loss.backward()
self.optimizer.update()
という感じになっています。
(たぶん)失敗してたポイント1: 入力の与え方
k/N から k が出せるのだから、10状態は1つの実数値で良いだろうというのが 浅はかだった と思います。
確かにこの方法でも瞬間的に33〜34のスコア(ベストな行動)を行うのですが、最終的には全然あさっての方にいってしまいます。
k/Nな入力だと意味的にすごく重要な値の境目があるので、重みがすごくデリケートになってしまうのかなと思います。
そこで Model を
self.model = FunctionSet(
l1=F.EmbedID(10, 10),
l2=F.Linear(10, 10),
l3=F.Linear(10, 4),
)
というように F.EmbedID
で入力を誤解しにくいように?しました。
F.EmbedID(in_size, out_size)
というのは in_sizeの種類のID(今回は0~9のnp.int32)を out_sizeのNodeで表現するというものです。 公式チュートリアルの2つ目にも登場しています。
(たぶん)失敗していたポイント2: 報酬の与え方
最初は +1, -100 という報酬を与えていましたが、どうもこれも まずい らしい(たまに Overflowするようだし)。
代わりに 0.01, -1 とScaleしてやると何やら上手くいきました。
とりあえず、そういうものだと思っておくことにしました。。
学習のパラメータ
- α:「理想とのズレ」をどの程度フィードバックするか
- E_GREEDY: 学習フェーズで、どの程度ランダムに行動を決めるか
というパラメータは、「α大きめ(0.05 -> 0.1)」「E_GREEDY大きめ(0.1 -> 0.3)」とする方が今回は良かった(学習が速かった)です。どの辺りがベストなのかはよくわかってないですが。
感覚的にはE_GREEDYは大きいほうがNNには向いているような気がします。
「あるときはかなりランダム」「あるときはランダムなし」を明確に分けるとなんか見てて楽しいです。
必要の無かった volatile
学習できないのは何かよくわからない副作用のせいなのか、、、と思っていて、公式チュートリアルの2つ目にあった valatileをつけたりしましたが、たぶんこれは不要でした(なくてもちゃんと学習しているし)。
これって、recurrent network を使う時におそらく有用な仕組みなのかな、、、と今は思っています(が、確信はない・・・)。backwardする予定がないときや、余分な computation history
を残したくない時に使う、と(computation historyが何を表すのかがやっぱりよくわかってないんですが h
と関係ある何かなのかな)。
結果
何度かやってみると、最初はいろいろさまよっているものの、 100〜150万回くらい学習するとだいたいベストな行動を取り続けるようになります。 Q-Tableを使っている時に比べて、なかなか安定しないのが面白いですね。
一応、最後は収束したのでこんな感じでOKなのかなと思っているのですが、
たぶん、何か実装か考え方が間違っている可能性は否めません。。
さいごに
次はもうちょっと(視覚的に)面白いゲームに挑戦してみようと思います。
Pythonだとcursesとか使って遊べますしね。