
Trying Actor-Critic (Deep DPG) with Chainer - Draft Edition

Posted at 2016-08-07

When people say reinforcement learning they tend to mean Deep Q-learning, and DQN gets all the hype, but for AlphaGo and machine learning on robots the field seems to have been shifting toward Actor-Critic for the last few years.

Meanwhile, for a manga brain domesticated by PowerPoint like mine, reading papers day after day just isn't fun, and there is no sign of plain-language explanations showing up either. So, in the spirit of the Hatoyama Initiative, I'll swallow my pride and turn my current understanding into a rough picture-story show.

Review

Reinforcement learning

Most explanations of reinforcement learning start with a picture like this.

(figure: the standard agent-environment loop)

When you use this for robot control, what the agent actually puts out is a motion command value rather than the resulting motion itself, so it looks more like this.

(figure: the same loop for robot control, with motion commands as the agent's output)

Playing god here, that is, designing the reward, is a pain in all sorts of ways; a joint paper from OpenAI and DeepMind mentioned this too, as I recall.

What the agent aims to maximize is the Q-value: the total of the rewards r over the long run, discounted as below.

Q = r_0 + \gamma \cdot r_1 + \gamma^2 \cdot r_2 + \cdots
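
As a sanity check, here is a minimal sketch of that sum in plain Python (the reward list and gamma are made-up example values):

# Minimal sketch: the discounted return Q = r_0 + gamma*r_1 + gamma^2*r_2 + ...
def discounted_return(rewards, gamma=0.95):
    q = 0.0
    for r in reversed(rewards):  # fold from the last reward backwards
        q = r + gamma * q
    return q

print discounted_return([1.0, 0.0, 0.0, 5.0])  # -> 1.0 + 0.95**3 * 5.0 = 5.286875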

Deep Q network

If we decide to use Q-learning for the reinforcement learning, and further approximate the Q-function with a deep neural network,

(figure: the loop with the Q-function as a deep neural network)

the function looks like this. The top layout turns the traditional Q-function into a neural network as-is; since that form is awkward for a network, the bottom layout, with one Q output per action, is the standard one in a deep Q-network.

(figure: the two Q-network layouts)

At each state, the agent takes the action whose expected Q-value is largest. As training proceeds, the function predicting the Q-values becomes more and more accurate, and the behavior gets optimized.
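
To make that concrete, a minimal sketch of the greedy choice over per-action Q outputs (plain numpy; the q_values here are made up, standing in for the network's forward pass):

import numpy as np

q_values = np.array([0.1, 0.8, -0.3])  # Q(s, a) for each discrete action a
action = int(np.argmax(q_values))      # greedy: take the largest Q
print action                           # -> 1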

Actor-critic DDPG (Deep Deterministic Policy Gradient)

Actor-Critic is the reinforcement-learning approach that separates the part estimating the Q-function from the part deciding the action for a given state, and the more you dig into it, the more variants you find (I seriously got lost).

In classical TD learning, if we rewrite the Q-function as

Q(s,a) = V(s) + A(s,a)

then, because

A(s,a) \approx r + \gamma V(s_{t+1}) - V(s_t) = \mathrm{TDError}

holds, the Q-function can be swapped for a value function V and a policy function π,

(figure: a value function V and a policy π in place of the Q-function)

and apparently you just learn those instead!

Approximating these value functions with linear functions gives continuous-action reinforcement learning via the action-value Actor-Critic type of policy gradient, roughly as sketched below.
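
Here is a minimal sketch of that idea: one TD(0) actor-critic step with linear function approximation and a Gaussian policy (a common flavor, not necessarily the exact variant meant; all names, shapes, and step sizes are made up):

import numpy as np

dim = 4
w     = np.zeros(dim)  # critic weights:  V(s) = w . s
theta = np.zeros(dim)  # actor weights:   a ~ N(theta . s, sigma^2)
sigma, gamma, alpha_v, alpha_p = 0.2, 0.95, 0.1, 0.01

s  = np.random.rand(dim)                            # current state (made up)
a  = np.dot(theta, s) + sigma * np.random.randn()   # sample an action
r  = 1.0                                            # observed reward (made up)
s2 = np.random.rand(dim)                            # next state (made up)

td_error = r + gamma * np.dot(w, s2) - np.dot(w, s)  # the TD error above
w     += alpha_v * td_error * s                      # critic update
theta += alpha_p * td_error * (a - np.dot(theta, s)) / sigma**2 * s  # actor update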

Now, in DDPG, the rough gist is: express this with a Q-function once again, and express that with a neural network.
(figure: DDPG overview)

Here the Q-value function becomes

(figure: the Q-value function)

and the policy function becomes

(figure: the policy function)

and with both functions being deep neural networks, that's DDPG, more or less.
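
For reference, the update rules from the original DDPG paper (Lillicrap et al., 2015); note the paper also keeps slowly-tracking target networks Q' and \mu', which this draft skips:

y_i = r_i + \gamma Q'(s_{i+1}, \mu'(s_{i+1}|\theta^{\mu'})|\theta^{Q'})

L = \frac{1}{N} \sum_i \left( y_i - Q(s_i, a_i|\theta^Q) \right)^2

\nabla_{\theta^\mu} J \approx \frac{1}{N} \sum_i \nabla_a Q(s,a|\theta^Q)\big|_{s=s_i,\,a=\mu(s_i)} \, \nabla_{\theta^\mu} \mu(s|\theta^\mu)\big|_{s_i}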

Implementation

Tested on the usual setup from ChainerでやってみるDeep Q Learning - 立ち上げ編 - Qiita. Build it once and you can keep reusing it, which is nice.

(animation: the agent roaming the test world)

Q-function

The agent's actions are the left and right wheel commands, so the network input is the state dimension plus 2.
The Q-function is trained with the teacher signal

y = R + \gamma \cdot Q(s', \pi(s'))

as the target:

class Q(Chain):
    def __init__(self, state_dim = STATE_DIM):
        super(Q, self).__init__(
            l1=F.Linear(state_dim + 2, 20),  # input: state plus the two wheel actions
            l2=F.Linear(20, 20),
            v_value=F.Linear(20, 1)          # scalar Q value
        )
    def __call__(self, x, t):
        # MSE against the teacher signal y
        return F.mean_squared_error(self.predict(x, train=True), t)

    def predict(self, x, train = False):
        h1 = F.leaky_relu(self.l1(x))
        h2 = F.leaky_relu(self.l2(h1))
        y = self.v_value(h2)
        return y
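
As a quick shape check, a sketch of running this class on a random batch (Chainer 1.x style, assuming the imports from the full source below; the data is random, for checking dimensions only):

q = Q()
x = Variable(np.random.rand(5, STATE_DIM + 2).astype(np.float32))  # 5 state+action rows
t = Variable(np.random.rand(5, 1).astype(np.float32))              # 5 dummy targets
print q.predict(x).data.shape  # -> (5, 1)
print q(x, t).data             # scalar MSE loss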

Policy function

For the policy function, it seems that in Chainer you have to define the gradient directly and drive the learning with it, rather than going through a loss function.

\frac{\partial Loss_\pi}{\partial \theta} = \frac{\partial Q}{\partial a} \frac{\partial \pi}{\partial \theta}

This can also be written as

\frac{\partial Loss_\pi}{\partial \theta} = \frac{\partial Q(s,\pi(s;\theta))}{\partial \theta}

so I thought it might be nicer to define Q and π together as one network, but I couldn't work out how to handle the noise-added action a during the learning phase, so I defined them separately.
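
For reference, a sketch of the combined variant I considered; untested, and my guess at the wiring (it assumes the modelP, modelQ, and optimizerP objects plus the p_seq state batch from the training code further down):

# Sketch: actor update via Loss_pi = -mean Q(s, pi(s)), chaining pi into Q
s_var   = Variable(p_seq)                            # batch of states
a_var   = modelP.predict(s_var)                      # actions from the policy
q_var   = modelQ.predict(F.concat((s_var, a_var)))   # Q(s, pi(s))
loss_pi = -F.sum(q_var) / q_var.data.shape[0]        # ascend Q = descend -Q
modelP.zerograds()
modelQ.zerograds()            # Q receives gradients too; they are simply discarded
loss_pi.backward()
optimizerP.update()           # update only the policy weights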

class PolicyNetwork(Chain):
    def __init__(self, state_dim = STATE_DIM):
        super(PolicyNetwork, self).__init__(
            l1=F.Linear(state_dim, 20),
            p_value=F.Linear(20, 2)  # the two wheel commands
        )
    def __call__(self, x, t):
        # unused in practice: the policy is trained through the injected gradient below
        return F.mean_squared_error(self.predict(x, train=True), t)

    def predict(self, x, train = False):
        h1 = F.leaky_relu(self.l1(x))
        y = self.p_value(h1)
        return y

Training

The training part uses experience replay, just like DQN.
For the policy update, wondering whether this is really OK, I just shove

\frac{\partial Q}{\partial a}

straight in!

    def update_model(self, r, state, a):
        print a, r        
        if len(self.eMem)<self.batch_num:
            return

        # sample a random minibatch from the experience memory
        memsize     = self.eMem.shape[0]
        batch_index = np.random.permutation(memsize)[:self.batch_num]
        batch       = np.array(self.eMem[batch_index], dtype=np.float32).reshape(self.batch_num, -1)

        # each row was packed as (state, reward, action, next state)
        p_seq  = batch[:, 0:STATE_DIM]
        r      = batch[:, STATE_DIM]
        a      = batch[:, STATE_DIM+1:STATE_DIM+3]
        new_seq= batch[:, STATE_DIM+3:2*STATE_DIM+3]

        # critic target: y = r + gamma * Q(s', pi(s'))
        x = Variable( np.hstack((p_seq,a)).astype(np.float32))
        na = self.modelP.predict(Variable(new_seq)).data
        tmpQ = self.modelQ.predict(Variable(np.hstack((new_seq,na)))).data.copy()  
        targetsQ = r.reshape((self.batch_num,-1)) + self.gamma * tmpQ

        t = Variable(np.array(targetsQ, dtype=np.float32).reshape((self.batch_num,-1)))
        
        # Q-function update
        self.modelQ.zerograds()
        loss=self.modelQ(x ,t)
        self.loss = loss.data
        loss.backward()
        self.optimizerQ.update()
        
        # Policy update: loss.backward() above also filled x.grad, so inject the
        # gradient w.r.t. the action columns of x as the policy output's gradient
        s = Variable(p_seq)
        self.modelP.zerograds()
        y = self.modelP.predict(s)
        y.grad = x.grad[:,-2:]
        y.backward()
        self.optimizerP.update()        

Results

I think the best thing is to run the code below and watch it with your own eyes.

Fine-tuning the rewards changes the behavior quite a bit.
Once the agent starts spinning it has a hard time breaking out, so just restart it when that happens.

If your understanding differs from mine anywhere, please let me know. I'm publishing this with the ulterior motive of having it corrected!

My environment:

  • Python 2.7
  • Chainer 1.7.2
  • numpy 1.11.0
  • wxPython

Source

It's messy, but here it is. Maybe it's about time I switched from wxPython to kivy.

# -*- coding: utf-8 -*-
import wx
import wx.lib
import wx.lib.plot as plot

import math
import random as rnd

from chainer import cuda, optimizers, FunctionSet, Variable, Chain
import chainer.functions as F

import numpy as np
import copy

# import pickle

np.random.seed(1023)

# Steps looking back
STATE_NUM = 2

# State
NUM_EYES  = 9
STATE_DIM = NUM_EYES * 3 * 2  # NUM_EYES eyes x 3 channels x STATE_NUM steps

class SState(object):
    def __init__(self):
        self.seq = np.ones((STATE_NUM, NUM_EYES*3), dtype=np.float32)

    def push_s(self, state):
        self.seq[1:STATE_NUM] = self.seq[0:STATE_NUM-1]
        self.seq[0] = state

    def fill_s(self, state):
        for i in range(0, STATE_NUM):
            self.seq[i] = state

class Q(Chain):
    def __init__(self, state_dim = STATE_DIM):
        super(Q, self).__init__(
            l1=F.Linear(state_dim + 2, 20),
            l2=F.Linear(20, 20),
            v_value=F.Linear(20, 1)
        )
    def __call__(self, x, t):
        return F.mean_squared_error(self.predict(x, train=True), t)

    def predict(self, x, train = False):
        h1 = F.leaky_relu(self.l1(x))
        h2 = F.leaky_relu(self.l2(h1))
        y = self.v_value(h2)
        return y            

class PolicyNetwork(Chain):
    def __init__(self, state_dim = STATE_DIM):
        super(PolicyNetwork, self).__init__(
            l1=F.Linear(state_dim, 20),
            l2=F.Linear(20, 20),
            p_value=F.Linear(20, 2)
        )
    def __call__(self, x, t):
        return F.mean_squared_error(self.predict(x, train=True), t)

    def predict(self, x, train = False):
        h1 = F.leaky_relu(self.l1(x))
        h2 = F.leaky_relu(self.l2(h1))
        y = self.p_value(h2)  # pass through both hidden layers
        return y        
            
class Walls(object):
    def __init__(self, x0, y0, x1, y1):
        self.xList = [x0, x1]
        self.yList = [y0, y1]
        self.P_color = wx.Colour(50,50,50)

    def addPoint(self, x, y):
        self.xList.append(x)
        self.yList.append(y)

    def Draw(self,dc):
        dc.SetPen(wx.Pen(self.P_color))
        for i in range(0, len(self.xList)-1):
            dc.DrawLine(self.xList[i], self.yList[i], self.xList[i+1],self.yList[i+1])

    def IntersectLine(self, p0, v0, i):
        dp = [p0[0] - self.xList[i], p0[1] - self.yList[i]]
        v1 = [self.xList[i+1] - self.xList[i], self.yList[i+1] - self.yList[i]]

        denom = float(v1[1]*v0[0] - v1[0]*v0[1])
        if denom == 0.0:
            return [False, 1.0]

        ua = (v1[0] * dp[1] - v1[1] * dp[0])/denom
        ub = (v0[0]*dp[1] - v0[1] * dp[0])/denom

        if 0 < ua and ua< 1.0 and 0 < ub and ub < 1.0:
            return [True, ua]

        return [False, 1.0]

    def IntersectLines(self, p0, v0):
        tmpt = 1.0
        tmpf = False
        for i in range(0, len(self.xList)-1):
            f,t = self.IntersectLine( p0, v0, i)
            if f:
                tmpt = min(tmpt, t)
                tmpf = True

        return [tmpf, tmpt]

class Ball(object):
    def __init__(self, x, y, color, property = 0):
        self.pos_x = x
        self.pos_y = y
        self.rad = 10 

        self.property = property

        self.B_color = color
        self.P_color = wx.Colour(50,50,50)

    def Draw(self, dc):
        dc.SetPen(wx.Pen(self.P_color))
        dc.SetBrush(wx.Brush(self.B_color))
        dc.DrawCircle(self.pos_x, self.pos_y, self.rad)

    def SetPos(self, x, y):
        self.pos_x = x
        self.pos_y = y

    def IntersectBall(self, p0, v0):
        # StackOverflow:Circle line-segment collision detection algorithm?
        # http://goo.gl/dk0yO1

        o = [-self.pos_x + p0[0], -self.pos_y + p0[1]]

        a = v0[0] ** 2 + v0[1] **2
        b = 2 * (o[0]*v0[0]+o[1]*v0[1])
        c = o[0] ** 2 + o[1] **2 - self.rad ** 2

        discriminant = float(b * b - 4 * a * c)

        if discriminant < 0:
            return [False, 1.0]

        discriminant = discriminant ** 0.5

        t1 = (- b - discriminant)/(2*a)
        t2 = (- b + discriminant)/(2*a)

        if t1 >= 0 and t1 <= 1.0:
            return [True, t1]

        if t2 >= 0 and t2 <= 1.0:
            return [True, t2]

        return [False, 1.0] 

class EYE(object):
    def __init__(self, i):
        self.OffSetAngle   = - math.pi/3 + i * math.pi*2/3/NUM_EYES
        self.SightDistance = 0
        self.obj           = -1
        self.FOV           = 130.0
        self.resetSightDistance()

    def resetSightDistance(self):
        self.SightDistance = self.FOV
        self.obj = -1

class Agent(Ball):
    def __init__(self, panel, x, y, alpha = 1.8):
        super(Agent, self).__init__(
            x, y, wx.Colour(112,146,190)
        )
        self.dir_Angle = math.pi/4
        self.speed     = 5

        self.pos_x_max, self.pos_y_max = panel.GetSize()
        self.pos_y_max = 480

        self.eyes = [ EYE(i) for i in range(0, NUM_EYES)]

        self.prevActions = np.zeros_like([])

        # Actor Critic Model
        self.theta = np.random.rand(2, STATE_DIM)
        self.sigma = np.array([[0.2],[0.2]])#np.random.rand(2, 1)
        self.w     = np.random.rand(1, STATE_DIM)
        
        
        # self.model
        self.modelQ = Q()
        self.optimizerQ = optimizers.Adam()
        self.optimizerQ.setup(self.modelQ)
        
        self.modelP = PolicyNetwork()
        self.optimizerP = optimizers.Adam()
        self.optimizerP.setup(self.modelP)

        # experience Memory
        self.eMem = np.array([],dtype = np.float32)
        self.memPos = 0
        self.memSize = 30000

        self.batch_num = 30
        self.loss = 0.0
        
        self.alpha = alpha
        self.gamma   = 0.95

        self.State = SState()
        self.prevState = np.ones((1,STATE_DIM))

    def UpdateState(self):
        s = np.ones((1, NUM_EYES * 3),dtype=np.float32)
        for i in range(0, NUM_EYES):
            if self.eyes[i].obj != -1:
                s[0, i * 3 + self.eyes[i].obj] = self.eyes[i].SightDistance / self.eyes[i].FOV
        self.State.push_s(s)

    def Draw(self, dc):
        dc.SetPen(wx.Pen(self.P_color))
        dc.SetBrush(wx.Brush(self.B_color))
        for e in self.eyes:
            if e.obj == 1:
                dc.SetPen(wx.Pen(wx.Colour(112,173,71)))
            elif e.obj == 2:
                dc.SetPen(wx.Pen(wx.Colour(237,125,49)))
            else:
                dc.SetPen(wx.Pen(self.P_color))

            dc.DrawLine(self.pos_x, self.pos_y, 
                    self.pos_x + e.SightDistance*math.cos(self.dir_Angle + e.OffSetAngle),
                    self.pos_y - e.SightDistance*math.sin(self.dir_Angle + e.OffSetAngle))

        super(Agent, self).Draw(dc)

    def get_action(self, state):
        x = Variable(state.reshape((1, -1)).astype(np.float32))
        y = self.modelP.predict(x)
        
        tmp = [max(min(y.data[0][0],0.9),-0.9),
               max(min(y.data[0][1],0.9),-0.9)]

        x = [tmp[0]+np.random.normal(0., self.alpha),
             tmp[1]+np.random.normal(0., self.alpha)]
        return x

    def reduce_alpha(self):
        self.alpha -= 1.0/2000
        self.alpha = max(0.02, self.alpha) 

    def Value(self, state):
        x = Variable(state.reshape((1, -1)).astype(np.float32))
        return self.modelQ.predict(x).data[0]

    def update_model(self, r, state, a):
        print a, r        
        if len(self.eMem)<self.batch_num:
            return

        memsize     = self.eMem.shape[0]
        batch_index = np.random.permutation(memsize)[:self.batch_num]
        batch       = np.array(self.eMem[batch_index], dtype=np.float32).reshape(self.batch_num, -1)

        # rows were packed as (state, reward, action, next state)
        p_seq  = batch[:, 0:STATE_DIM]
        r      = batch[:, STATE_DIM]
        a      = batch[:, STATE_DIM+1:STATE_DIM+3]
        new_seq= batch[:, STATE_DIM+3:2*STATE_DIM+3]

        x = Variable( np.hstack((p_seq,a)).astype(np.float32))
        na = self.modelP.predict(Variable(new_seq)).data
        tmpQ = self.modelQ.predict(Variable(np.hstack((new_seq,na)))).data.copy()  
        targetsQ = r.reshape((self.batch_num,-1)) + self.gamma * tmpQ

        t = Variable(np.array(targetsQ, dtype=np.float32).reshape((self.batch_num,-1)))
        
        # Q-function update
        self.modelQ.zerograds()
        loss=self.modelQ(x ,t)
        self.loss = loss.data
        loss.backward()
        self.optimizerQ.update()
        
        # Policy update
        s = Variable(p_seq)
        self.modelP.zerograds()
        y = self.modelP.predict(s)
        y.grad = x.grad[:,-2:]
        y.backward()
        self.optimizerP.update()        

    def experience(self,x):
        if self.eMem.shape[0] > self.memSize:
            self.eMem[int(self.memPos%self.memSize)] = x
            self.memPos+=1
        elif self.eMem.shape[0] == 0:
            self.eMem = x
        else:
            self.eMem = np.vstack( (self.eMem, x) )        
        
    def Move(self, WallsList):
        flag = False
        dp = [ self.speed * math.cos(self.dir_Angle),
              -self.speed * math.sin(self.dir_Angle)]

        for w in WallsList:
            if w.IntersectLines([self.pos_x, self.pos_y], dp)[0]:
                dp = [0.0, 0.0]
                flag = True

        self.pos_x += dp[0] 
        self.pos_y += dp[1]

        self.pos_x = max(0, min(self.pos_x, self.pos_x_max))
        self.pos_y = max(0, min(self.pos_y, self.pos_y_max))
		
        reward = (dp[0]**2+dp[1]**2)**0.5
		
        return flag, reward

    def HitBall(self, b):
        if ((b.pos_x - self.pos_x)**2+(b.pos_y - self.pos_y)**2)**0.5 < (self.rad + b.rad):
            return True
        return False

class World(wx.Frame):
    def __init__(self, parent=None, id=-1, title=None):
        wx.Frame.__init__(self, parent, id, title)
        self.panel = wx.Panel(self, size=(640, 640))
        self.panel.SetBackgroundColour('WHITE')
        self.Fit()

        self.A = Agent(self.panel, 350, 150 )

        self.greenB = [Ball(rnd.randint(40, 600),rnd.randint(40, 440), 
                        wx.Colour(112,173,71), property =  1) for i in range(0, 15)]
        self.redB  = [Ball(rnd.randint(40, 600),rnd.randint(40, 440), 
                        wx.Colour(237,125,49), property = 2) for i in range(0, 10)]

        # Outer box
        self.Box = Walls(640, 480, 0, 480)
        self.Box.addPoint(0,0)
        self.Box.addPoint(640,0)
        self.Box.addPoint(640,480)

        # Wall in the world
        #self.WallA = Walls(96, 90, 256, 90)
        #self.WallA.addPoint(256, 390)
        #self.WallA.addPoint(96,390)

        self.Bind(wx.EVT_CLOSE, self.CloseWindow)

        self.cdc = wx.ClientDC(self.panel)
        w, h = self.panel.GetSize()
        self.bmp = wx.EmptyBitmap(w,h)

        self.timer = wx.Timer(self)
        self.Bind(wx.EVT_TIMER, self.OnTimer)
        self.timer.Start(20)

        # Plot
        # https://www.daniweb.com/programming/software-development/code/216913/using-wxpython-for-plotting
        #self.plotter = plot.PlotCanvas(self, pos=(0,480),size=(400,240))
        #self.plotter.SetEnableZoom(True)
        #self.plotter.SetEnableLegend(True)
        #self.plotter.SetFontSizeLegend(20)

    def CloseWindow(self, event):
        # self.timer.Stop()
        wx.Exit()

    def OnTimer(self, event):     
        # Update States
        self.A.UpdateState()
        state  = self.A.State.seq.reshape(1,-1)
        action = self.A.get_action(state)

        # Action Step
        action[0] = max(min(action[0], 1.0), -1.0)
        action[1] = max(min(action[1], 1.0), -1.0)

        self.A.speed = (action[0] + action[1])/2.0 * 5.0
        tmp = math.atan(max(min((action[0] - action[1])/ 2.0 / 2.5, 2.0), -2.0))
        self.A.dir_Angle += tmp
        self.A.dir_Angle = ( self.A.dir_Angle + np.pi) % (2 * np.pi ) - np.pi

        flag,rrr = self.A.Move([self.Box])

        digestion_reward = 0.0        
        for g in self.greenB:
            if self.A.HitBall(g):
                g.SetPos(rnd.randint(40, 600),rnd.randint(40, 440))
                digestion_reward -= 6.0
        for r in self.redB:
            if self.A.HitBall(r):
                r.SetPos(rnd.randint(40, 600),rnd.randint(40, 440))
                digestion_reward += 5.0

        # Reward
        proximity_reward = 0.0
        for e in self.A.eyes:
            proximity_reward += float(e.SightDistance)/e.FOV if e.obj == 0 else 1.0
        proximity_reward /= NUM_EYES
        proximity_reward = min(1.0, proximity_reward*2)

        forward_reward   = 0.0
        if(self.A.speed > 2.0 and proximity_reward > 0.75):
            forward_reward = 0.1
        elif(self.A.speed < 1.0):
            forward_reward = -0.2
        if (self.A.speed > 0.0 and tmp > -np.pi/8 and tmp < np.pi/8):
            forward_reward += 0.1
            
        if flag:
            wall_punish = -1.0
        else:
            wall_punish = 0.0
        reward = proximity_reward + forward_reward + digestion_reward + wall_punish+rrr/2.0
        print ("reward:%.2f %.2f %.2f %.2f")%(proximity_reward,forward_reward
                                            ,digestion_reward, wall_punish+rrr/2.0)

        # Learning Step
        self.A.experience(np.hstack([
                        self.A.prevState,
                        np.array([reward]).reshape(1,-1),
                        np.array([action]).reshape(1,-1),
                        state
                    ]))
        self.A.update_model(reward, state, action)
        self.A.prevState = state.copy()
        self.A.reduce_alpha()

        # Graphics Update
        for e in self.A.eyes:
            e.resetSightDistance()
            p = [self.A.pos_x, self.A.pos_y]
            s = math.sin(self.A.dir_Angle + e.OffSetAngle)
            c = math.cos(self.A.dir_Angle + e.OffSetAngle)

            for g in self.greenB:
                f, t = g.IntersectBall(p, [e.SightDistance * c, - e.SightDistance * s])
                if f:
                    e.SightDistance *= t
                    e.obj = g.property

            for r in self.redB:
                f, t = r.IntersectBall(p, [e.SightDistance * c, - e.SightDistance * s])
                if f:
                    e.SightDistance *= t
                    e.obj = r.property

            for w in [self.Box]:
                f, t = w.IntersectLines(p, [e.SightDistance * c, - e.SightDistance * s])
                if f:
                    e.SightDistance *= t
                    e.obj = 0

        self.bdc = wx.BufferedDC(self.cdc, self.bmp)
        self.gcdc = wx.GCDC(self.bdc)
        self.gcdc.Clear()

        self.gcdc.SetPen(wx.Pen('white'))
        self.gcdc.SetBrush(wx.Brush('white'))
        self.gcdc.DrawRectangle(0,0,640,640)

        self.A.Draw(self.gcdc)
        for g in self.greenB:
            g.Draw(self.gcdc)
        for r in self.redB:
            r.Draw(self.gcdc)

        self.Box.Draw(self.gcdc)
        #self.WallA.Draw(self.gcdc)

        #line = plot.PolyLine(ext.extract(), colour='blue', width=1, legend=filename)
        #graph = plot.PlotGraphics([line], 'Wave', 't', 'f(t)')
        #self.plotter.Draw(graph, xAxis=(0,len(ext.wav)), yAxis=(ext.wav.minvalue, ext.wav.maxvalue))

if __name__ == '__main__':
    app = wx.PySimpleApp()
    w = World(title='RL Test')
    w.Center()
    w.Show()
    app.MainLoop()
