Edited at

ChainerでやってみるActor-Critic(Deep DPG) - ドラフト編

More than 1 year has passed since last update.

強化学習といえばDeep Q learningみたいなのりで、DQNがもてはやされていますが、AlphaGoとかロボットの機械学習では数年前くらいからActor-Criticに移行してきているように見えます。

その一方でパワーポイントに飼い慣らされた漫画お脳には論文なんて読んでも面白みがない感じの毎日なのに、一方に解説が出てくる気配が感じられません。ということで、鳩山イニシアチブが如く、恥を忍んで今の理解をざっくり紙芝居にします。


復習


強化学習

だいたい世の強化学習ってこんな絵で始まります。

image.jpeg

これをロボット制御に使う場合は実際の出力は、動作指令値であって出力ではなく、こんな感じ。

image.jpeg

神の設計というか、リワードの設計がいろいろ面倒で、OpenAIとかdeep mindとかからの共同論文にも言及ありましたね。

ここでAgentが獲得を目指すのはQ値で評価される値で長期的にみて報酬rの合計値が最大化される値でしたね。

Q = r_0 + \gamma \cdot r_1 + \gamma ^2 \cdot r_2 + \cdots 


Deep Q network

強化学習でq学習を使うことにして、更にq関数をディープなニューラルネットワークで近似することにすると、

image.jpeg

関数はこんな感じでした。上が伝統的q関数をそのままニューラルネットワークにしたので、ニューラルネットで表現するのに不便なので下がdeep q networkではセオリーなq関数。

image.jpeg

その時々の状態(state)に対してQ値の最大が期待される行動をとっていくようにします。学習の過程でQ値を予測する関数をどんどん正確にしていき、行動を最適化します。


Actor-critic DDPG (Deep Deterministic Policy Gradient)

Q関数を求めるところと状態に応じた行動を決定する部分を分けたのがActor-Criticという強化学習方法で、調べれば調べるほど色んなタイプがあることがわかります(いや、本当に迷子になりました)

古典的なTD学習では、

Q(s,a) = V(s) + A(s,a)

という表現に置き換えてやると、

A(s,a) \approx r + \gamma V(s_{t+1}) -V(s_t) = TDError

であることから、q関数のところを価値関数vと政策関数πに置き換えて、

image.jpeg

としてしまっているんだって!

これらの価値関数を線形関数で近似したのがAction value Actor-Critic型Policy Gradientによる連続値動作の強化学習です。

さて、DDPGでは、これを改めてQ関数で表現し、それをNNで表現する…というのがだいたいの骨子です。

DDPG.png

ここでQ値関数を

DDPG2.png

政策関数を

DDPG3.png

にして、かつ関数がディープなニューラルネットワークなのがDDPGという感じ。


実装

いつものChainerでやってみるDeep Q Learning - 立ち上げ編 - Qiitaで試験。一度作ると使いまわせていいですね。

005.gif


Q関数

エージェントは左右輪のアクションを取るので、ステート数+2のネットワーク。

Q関数は教師信号

y = R + \gamma \cdot Q(s', \pi(s'))

をターゲットに学習する

class Q(Chain):

def __init__(self, state_dim = STATE_DIM):
super(Q, self).__init__(
l1=F.Linear(state_dim + 2, 20),
l2=F.Linear(20, 20),
v_value=F.Linear(20, 1)
)
def __call__(self, x, t):
return F.mean_squared_error(self.predict(x, train=True), t)

def predict(self, x, train = False):
h1 = F.leaky_relu(self.l1(x))
h2 = F.leaky_relu(self.l2(h1))
y = self.v_value(h2)
return y


政策関数

政策関数については,Chainerではロス関数というより勾配を直接定義して学習を進める必要がある様子。

\frac{\partial Loss_\pi}{\partial \theta} = \frac{\partial Q}{\partial a} \frac{\partial \pi}{\partial \theta}

これについては、

\frac{\partial Loss_\pi}{\partial \theta} = \frac{\partial Q(s,\pi(s;\theta))}{\partial \theta}

とも表せることからQとπをまとめて定義しちゃう方がいいかなとも思ったのですが、学習フェーズでノイズを負荷したaの処置がよくわからなかったので別個に定義しました。

class PolicyNetwork(Chain):

def __init__(self, state_dim = STATE_DIM):
super(PolicyNetwork, self).__init__(
l1=F.Linear(state_dim, 20),
p_value=F.Linear(20, 2)
)
def __call__(self, x, t):
return F.mean_squared_error(self.predict(x, train=True), t)

def predict(self, x, train = False):
h1 = F.leaky_relu(self.l1(x))
y = self.p_value(h1)
return y


学習

学習のところはDQNと同様にexperience replayです。

ポリシーの学習については、これでいいのかなと思いつつ、

\frac{\partial Q}{\partial a}

を放り込んでます!

    def update_model(self, r, state, a):

print a, r
if len(self.eMem)<self.batch_num:
return

memsize = self.eMem.shape[0]
batch_index = np.random.permutation(memsize)[:self.batch_num]
batch = np.array(self.eMem[batch_index], dtype=np.float32).reshape(self.batch_num, -1)

p_seq = batch[:, 0:STATE_DIM]
r = batch[:, STATE_DIM]
a = batch[:, STATE_DIM:STATE_DIM+2]
new_seq= batch[:,(STATE_DIM+2):(STATE_DIM*2+2)]

x = Variable( np.hstack((p_seq,a)).astype(np.float32))
na = self.modelP.predict(Variable(new_seq)).data
tmpQ = self.modelQ.predict(Variable(np.hstack((new_seq,na)))).data.copy()
targetsQ = r.reshape((self.batch_num,-1)) + self.gamma * tmpQ

t = Variable(np.array(targetsQ, dtype=np.float32).reshape((self.batch_num,-1)))

# Q関数更新
self.modelQ.zerograds()
loss=self.modelQ(x ,t)
self.loss = loss.data
loss.backward()
self.optimizerQ.update()

# Policy更新
s = Variable(p_seq)
self.modelP.zerograds()
y = self.modelP.predict(s)
y.grad = x.grad[:,-2:]
y.backward()
self.optimizerP.update()


結果

下記実行していただいて目で見ていただくのがいいと思います。

リワードの微調整をしてやると挙動が結構変わります。

一度スピンを始めると中々抜け出せないのでそういう時は再起動してやってください。

考え方の違うところがあったら教えてください。添削していただける下心で公開してます!

環境はこちら

* Python 2.7

* Chainer 1.7.2

* numpy 1.11.0

* wxPython


ソース

汚いっすが貼っておきますね。そろそろwxPythonじゃなくてkivyにしようかな。

# -*- coding: utf-8 -*-

import wx
import wx.lib
import wx.lib.plot as plot

import math
import random as rnd

from chainer import cuda, optimizers, FunctionSet, Variable, Chain
import chainer.functions as F

import numpy as np
import copy

# import pickle

np.random.seed(1023)

# Steps looking back
STATE_NUM = 2

# State
STATE_NUM = 2
NUM_EYES = 9
STATE_DIM = NUM_EYES * 3 * 2

class SState(object):
def __init__(self):
self.seq = np.ones((STATE_NUM, NUM_EYES*3), dtype=np.float32)

def push_s(self, state):
self.seq[1:STATE_NUM] = self.seq[0:STATE_NUM-1]
self.seq[0] = state

def fill_s(self, state):
for i in range(0, STATE_NUM):
self.seq[i] = state

class Q(Chain):
def __init__(self, state_dim = STATE_DIM):
super(Q, self).__init__(
l1=F.Linear(state_dim + 2, 20),
l2=F.Linear(20, 20),
v_value=F.Linear(20, 1)
)
def __call__(self, x, t):
return F.mean_squared_error(self.predict(x, train=True), t)

def predict(self, x, train = False):
h1 = F.leaky_relu(self.l1(x))
h2 = F.leaky_relu(self.l2(h1))
y = self.v_value(h2)
return y

class PolicyNetwork(Chain):
def __init__(self, state_dim = STATE_DIM):
super(PolicyNetwork, self).__init__(
l1=F.Linear(state_dim, 20),
l2=F.Linear(20, 20),
p_value=F.Linear(20, 2)
)
def __call__(self, x, t):
return F.mean_squared_error(self.predict(x, train=True), t)

def predict(self, x, train = False):
h1 = F.leaky_relu(self.l1(x))
h2 = F.leaky_relu(self.l2(h1))
y = self.p_value(h1)
return y

class Walls(object):
def __init__(self, x0, y0, x1, y1):
self.xList = [x0, x1]
self.yList = [y0, y1]
self.P_color = wx.Colour(50,50,50)

def addPoint(self, x, y):
self.xList.append(x)
self.yList.append(y)

def Draw(self,dc):
dc.SetPen(wx.Pen(self.P_color))
for i in range(0, len(self.xList)-1):
dc.DrawLine(self.xList[i], self.yList[i], self.xList[i+1],self.yList[i+1])

def IntersectLine(self, p0, v0, i):
dp = [p0[0] - self.xList[i], p0[1] - self.yList[i]]
v1 = [self.xList[i+1] - self.xList[i], self.yList[i+1] - self.yList[i]]

denom = float(v1[1]*v0[0] - v1[0]*v0[1])
if denom == 0.0:
return [False, 1.0]

ua = (v1[0] * dp[1] - v1[1] * dp[0])/denom
ub = (v0[0]*dp[1] - v0[1] * dp[0])/denom

if 0 < ua and ua< 1.0 and 0 < ub and ub < 1.0:
return [True, ua]

return [False, 1.0]

def IntersectLines(self, p0, v0):
tmpt = 1.0
tmpf = False
for i in range(0, len(self.xList)-1):
f,t = self.IntersectLine( p0, v0, i)
if f:
tmpt = min(tmpt, t)
tmpf = True

return [tmpf, tmpt]

class Ball(object):
def __init__(self, x, y, color, property = 0):
self.pos_x = x
self.pos_y = y
self.rad = 10

self.property = property

self.B_color = color
self.P_color = wx.Colour(50,50,50)

def Draw(self, dc):
dc.SetPen(wx.Pen(self.P_color))
dc.SetBrush(wx.Brush(self.B_color))
dc.DrawCircle(self.pos_x, self.pos_y, self.rad)

def SetPos(self, x, y):
self.pos_x = x
self.pos_y = y

def IntersectBall(self, p0, v0):
# StackOverflow:Circle line-segment collision detection algorithm?
# http://goo.gl/dk0yO1

o = [-self.pos_x + p0[0], -self.pos_y + p0[1]]

a = v0[0] ** 2 + v0[1] **2
b = 2 * (o[0]*v0[0]+o[1]*v0[1])
c = o[0] ** 2 + o[1] **2 - self.rad ** 2

discriminant = float(b * b - 4 * a * c)

if discriminant < 0:
return [False, 1.0]

discriminant = discriminant ** 0.5

t1 = (- b - discriminant)/(2*a)
t2 = (- b + discriminant)/(2*a)

if t1 >= 0 and t1 <= 1.0:
return [True, t1]

if t2 >= 0 and t2 <= 1.0:
return [True, t2]

return [False, 1.0]

class EYE(object):
def __init__(self, i):
self.OffSetAngle = - math.pi/3 + i * math.pi*2/3/NUM_EYES
self.SightDistance = 0
self.obj = -1
self.FOV = 130.0
self.resetSightDistance()

def resetSightDistance(self):
self.SightDistance = self.FOV
self.obj = -1

class Agent(Ball):
def __init__(self, panel, x, y, alpha = 1.8):
super(Agent, self).__init__(
x, y, wx.Colour(112,146,190)
)
self.dir_Angle = math.pi/4
self.speed = 5

self.pos_x_max, self.pos_y_max = panel.GetSize()
self.pos_y_max = 480

self.eyes = [ EYE(i) for i in range(0, NUM_EYES)]

self.prevActions = np.zeros_like([])

# Actor Critic Model
self.theta = np.random.rand(2, STATE_DIM)
self.sigma = np.array([[0.2],[0.2]])#np.random.rand(2, 1)
self.w = np.random.rand(1, STATE_DIM)

# self.model
self.modelQ = Q()
self.optimizerQ = optimizers.Adam()
self.optimizerQ.setup(self.modelQ)

self.modelP = PolicyNetwork()
self.optimizerP = optimizers.Adam()
self.optimizerP.setup(self.modelP)

# experience Memory
self.eMem = np.array([],dtype = np.float32)
self.memPos = 0
self.memSize = 30000

self.batch_num = 30
self.loss = 0.0

self.alpha = alpha
self.gamma = 0.95

self.State = SState()
self.prevState = np.ones((1,STATE_DIM))

def UpdateState(self):
s = np.ones((1, NUM_EYES * 3),dtype=np.float32)
for i in range(0, NUM_EYES):
if self.eyes[i].obj != -1:
s[0, i * 3 + self.eyes[i].obj] = self.eyes[i].SightDistance / self.eyes[i].FOV
self.State.push_s(s)

def Draw(self, dc):
dc.SetPen(wx.Pen(self.P_color))
dc.SetBrush(wx.Brush(self.B_color))
for e in self.eyes:
if e.obj == 1:
dc.SetPen(wx.Pen(wx.Colour(112,173,71)))
elif e.obj == 2:
dc.SetPen(wx.Pen(wx.Colour(237,125,49)))
else:
dc.SetPen(wx.Pen(self.P_color))

dc.DrawLine(self.pos_x, self.pos_y,
self.pos_x + e.SightDistance*math.cos(self.dir_Angle + e.OffSetAngle),
self.pos_y - e.SightDistance*math.sin(self.dir_Angle + e.OffSetAngle))

super(Agent, self).Draw(dc)

def get_action(self, state):
x = Variable(state.reshape((1, -1)).astype(np.float32))
y = self.modelP.predict(x)

tmp = [max(min(y.data[0][0],0.9),-0.9),
max(min(y.data[0][1],0.9),-0.9)]

x = [tmp[0]+np.random.normal(0., self.alpha),
tmp[1]+np.random.normal(0., self.alpha)]
return x

def reduce_alpha(self):
self.alpha -= 1.0/2000
self.alpha = max(0.02, self.alpha)

def Value(self, state):
x = Variable(state.reshape((1, -1)).astype(np.float32))
return self.modelQ.predict(x).data[0]

def update_model(self, r, state, a):
print a, r
if len(self.eMem)<self.batch_num:
return

memsize = self.eMem.shape[0]
batch_index = np.random.permutation(memsize)[:self.batch_num]
batch = np.array(self.eMem[batch_index], dtype=np.float32).reshape(self.batch_num, -1)

p_seq = batch[:, 0:STATE_DIM]
r = batch[:, STATE_DIM]
a = batch[:, STATE_DIM:STATE_DIM+2]
new_seq= batch[:,(STATE_DIM+2):(STATE_DIM*2+2)]

x = Variable( np.hstack((p_seq,a)).astype(np.float32))
na = self.modelP.predict(Variable(new_seq)).data
tmpQ = self.modelQ.predict(Variable(np.hstack((new_seq,na)))).data.copy()
targetsQ = r.reshape((self.batch_num,-1)) + self.gamma * tmpQ

t = Variable(np.array(targetsQ, dtype=np.float32).reshape((self.batch_num,-1)))

# Q関数更新
self.modelQ.zerograds()
loss=self.modelQ(x ,t)
self.loss = loss.data
loss.backward()
self.optimizerQ.update()

# Policy更新
s = Variable(p_seq)
self.modelP.zerograds()
y = self.modelP.predict(s)
y.grad = x.grad[:,-2:]
y.backward()
self.optimizerP.update()

def experience(self,x):
if self.eMem.shape[0] > self.memSize:
self.eMem[int(self.memPos%self.memSize)] = x
self.memPos+=1
elif self.eMem.shape[0] == 0:
self.eMem = x
else:
self.eMem = np.vstack( (self.eMem, x) )

def Move(self, WallsList):
flag = False
dp = [ self.speed * math.cos(self.dir_Angle),
-self.speed * math.sin(self.dir_Angle)]

for w in WallsList:
if w.IntersectLines([self.pos_x, self.pos_y], dp)[0]:
dp = [0.0, 0.0]
flag = True

self.pos_x += dp[0]
self.pos_y += dp[1]

self.pos_x = max(0, min(self.pos_x, self.pos_x_max))
self.pos_y = max(0, min(self.pos_y, self.pos_y_max))

reward = (dp[0]**2+dp[1]**2)**0.5

return flag, reward

def HitBall(self, b):
if ((b.pos_x - self.pos_x)**2+(b.pos_y - self.pos_y)**2)**0.5 < (self.rad + b.rad):
return True
return False

class World(wx.Frame):
def __init__(self, parent=None, id=-1, title=None):
wx.Frame.__init__(self, parent, id, title)
self.panel = wx.Panel(self, size=(640, 640))
self.panel.SetBackgroundColour('WHITE')
self.Fit()

self.A = Agent(self.panel, 350, 150 )

self.greenB = [Ball(rnd.randint(40, 600),rnd.randint(40, 440),
wx.Colour(112,173,71), property = 1) for i in range(0, 15)]
self.redB = [Ball(rnd.randint(40, 600),rnd.randint(40, 440),
wx.Colour(237,125,49), property = 2) for i in range(0, 10)]

# OutrBox
self.Box = Walls(640, 480, 0, 480)
self.Box.addPoint(0,0)
self.Box.addPoint(640,0)
self.Box.addPoint(640,480)

# Wall in the world
#self.WallA = Walls(96, 90, 256, 90)
#self.WallA.addPoint(256, 390)
#self.WallA.addPoint(96,390)

self.Bind(wx.EVT_CLOSE, self.CloseWindow)

self.cdc = wx.ClientDC(self.panel)
w, h = self.panel.GetSize()
self.bmp = wx.EmptyBitmap(w,h)

self.timer = wx.Timer(self)
self.Bind(wx.EVT_TIMER, self.OnTimer)
self.timer.Start(20)

# Plot
# https://www.daniweb.com/programming/software-development/code/216913/using-wxpython-for-plotting
#self.plotter = plot.PlotCanvas(self, pos=(0,480),size=(400,240))
#self.plotter.SetEnableZoom(True)
#self.plotter.SetEnableLegend(True)
#self.plotter.SetFontSizeLegend(20)

def CloseWindow(self, event):
# self.timer.Stop()
wx.Exit()

def OnTimer(self, event):
# Update States
self.A.UpdateState()
state = self.A.State.seq.reshape(1,-1)
action = self.A.get_action(state)

# Action Step
action[0] = max(min(action[0], 1.0), -1.0)
action[1] = max(min(action[1], 1.0), -1.0)

self.A.speed = (action[0] + action[1])/2.0 * 5.0
tmp = math.atan(max(min((action[0] - action[1])/ 2.0 / 2.5, 2.0), -2.0))
self.A.dir_Angle += tmp
self.A.dir_Angle = ( self.A.dir_Angle + np.pi) % (2 * np.pi ) - np.pi

flag,rrr = self.A.Move([self.Box])

digestion_reward = 0.0
for g in self.greenB:
if self.A.HitBall(g):
g.SetPos(rnd.randint(40, 600),rnd.randint(40, 440))
digestion_reward -= 6.0
for r in self.redB:
if self.A.HitBall(r):
r.SetPos(rnd.randint(40, 600),rnd.randint(40, 440))
digestion_reward += 5.0

# Reward
proximity_reward = 0.0
for e in self.A.eyes:
proximity_reward += float(e.SightDistance)/e.FOV if e.obj == 0 else 1.0
proximity_reward /= NUM_EYES
proximity_reward = min(1.0, proximity_reward*2)

forward_reward = 0.0
if(self.A.speed > 2.0 and proximity_reward > 0.75):
forward_reward = 0.1
elif(self.A.speed < 1.0):
foward_reward = -0.2
if (self.A.speed > 0.0 and tmp > -np.pi/8 and tmp < np.pi/8):
forward_reward += 0.1

if flag:
wall_punish = -1.0
else:
wall_punish = 0.0
reward = proximity_reward + forward_reward + digestion_reward + wall_punish+rrr/2.0
print ("reward:%.2f %.2f %.2f %.2f")%(proximity_reward,forward_reward
,digestion_reward, wall_punish+rrr/2.0)

# Learning Step
self.A.experience(np.hstack([
self.A.prevState,
np.array([reward]).reshape(1,-1),
np.array([action]).reshape(1,-1),
state
]))
self.A.update_model(reward, state, action)
self.A.prevState = state.copy()
self.A.reduce_alpha()

# Graphics Update
for e in self.A.eyes:
e.resetSightDistance()
p = [self.A.pos_x, self.A.pos_y]
s = math.sin(self.A.dir_Angle + e.OffSetAngle)
c = math.cos(self.A.dir_Angle + e.OffSetAngle)

for g in self.greenB:
f, t = g.IntersectBall(p, [e.SightDistance * c, - e.SightDistance * s])
if f:
e.SightDistance *= t
e.obj = g.property

for r in self.redB:
f, t = r.IntersectBall(p, [e.SightDistance * c, - e.SightDistance * s])
if f:
e.SightDistance *= t
e.obj = r.property

for w in [self.Box]:
f, t = w.IntersectLines(p, [e.SightDistance * c, - e.SightDistance * s])
if f:
e.SightDistance *= t
e.obj = 0

self.bdc = wx.BufferedDC(self.cdc, self.bmp)
self.gcdc = wx.GCDC(self.bdc)
self.gcdc.Clear()

self.gcdc.SetPen(wx.Pen('white'))
self.gcdc.SetBrush(wx.Brush('white'))
self.gcdc.DrawRectangle(0,0,640,640)

self.A.Draw(self.gcdc)
for g in self.greenB:
g.Draw(self.gcdc)
for r in self.redB:
r.Draw(self.gcdc)

self.Box.Draw(self.gcdc)
#self.WallA.Draw(self.gcdc)

#line = plot.PolyLine(ext.extract(), colour='blue', width=1, legend=filename)
#graph = plot.PlotGraphics([line], 'Wave', 't', 'f(t)')
#self.plotter.Draw(graph, xAxis=(0,len(ext.wav)), yAxis=(ext.wav.minvalue, ext.wav.maxvalue))

if __name__ == '__main__':
app = wx.PySimpleApp()
w = World(title='RL Test')
w.Center()
w.Show()
app.MainLoop()


参考

http://www.slideshare.net/carpedm20/continuous-control-with-deep-reinforcement-learning-ddpg