LoginSignup
1
0

More than 5 years have passed since last update.

CNTKの強化学習用パッケージdeeprlの使い方

Last updated at Posted at 2017-11-02

始めに

deeprlは、CNTK v.2.1で追加された強化学習用パッケージです。
(contribパッケージのサブパッケージに当たります)
今回は、CNTK v.2.2時点のdeeprlについての簡単な解説と、サンプルコードを紹介します。

※deeprlは最近追加されたばかりという事もあり、v.2.2以降のアップデートで大幅な変更が入ると
 予想されます。今後、記事の内容との相違が発生する可能性がありますので、ご注意下さい。

deeprlでできること

現在、deeprlには以下の強化学習手法が実装されています。
https://github.com/Microsoft/CNTK/tree/master/bindings/python/cntk/contrib/deeprl

  • Tabular Q-learning(Q関数をQテーブルで表現)
  • DQN
  • Double-DQN
  • Actor-Critic

deeprlサブパッケージの中身

APIリファレンスを参照して、deeprlサブパッケージの構成と機能を確認してみます。

Subpackages Submodules 機能
agent agent agentの基底クラス
agent agent_factory configファイルを読み込み、agentインスタンスを生成
agent policy_gradient Actor-Critic agentを定義
agent qlearning Q-Learning agentを定義(Q関数をNNで表現)
agent random_agent ランダムな行動を取るagentを定義
agent tabular_qlearning Q-Learning agentを定義(Q関数をQテーブルで表現)
agent.shared cntk_utils hugelossなどの損失関数を定義
agent.shared customized_models Nature版DQN等の有名ネットワークを定義
agent.shared discretize 連続な状態空間を離散値に変換
agent.shared models feedforward network, dueling networkを定義
agent.shared policy_gradient_parameters configファイルからパラメータを読み込む(Policy Grandient用)
agent.shared preprocessing データの前処理を定義
agent.shared qlearning_parameters configファイルからパラメータを読み込む(Q Learning用)
agent.shared replay_memory ReplayMemoryを定義
tests 単体テスト用

サンプルコード

deeprlのDouble-DQN agentを使い、CartPole問題を解いてみます。

import os
import gym
from cntk.contrib.deeprl.agent import qlearning

まずは、gymパッケージを使い、CartPole環境を読み込みます。

env = gym.make('CartPole-v0')
o_space = env.observation_space
a_space= env.action_space

次はagentの生成になります。今回はDouble-DQNによる学習を行うので、
qlearningサブモジュールのQLearningクラスを使い、agentインスタンスを生成します。

# Create an agent instance
agent = qlearning.QLearning(config_filename='config_examples/qlearning.config', 
                            o_space=o_space, 
                            a_space=a_space)
# Save parameter setting to file.
agent.save_parameter_settings('./save/qlearning.params')

学習前の処理は以上で終わりです。
というのは、学習設定は全てconfigファイルに保存し、それを読み込ませているからです。
実装されている手法毎にconfigファイル例が用意されているため、DQN以外の手法を使う場合は
そちらを参照してください。
今回の計算では、以下の内容のconfigファイルを使用しています。

qlearning.config
# See cntk.contrib.deeprl.agent.shared.qlearning_parameters for detailed
# explanation of each parameter.

[General]
Agent = qlearning
Gamma = 0.95
# PreProcessing = cntk.contrib.deeprl.agent.shared.preprocessing.AtariPreprocessing
# PreProcessingArgs = (4,)

[QLearningAlgo]
InitialEpsilon = 1.00
EpsilonDecayStepCount = 10000
EpsilonMinimum = 0.10
InitialQ = 0.0
TargetQUpdateFrequency = 100
QUpdateFrequency = 1
MinibatchSize = 16
# QRepresentation can be 'dqn', 'dueling-dqn', or some customized model defined as
# module_name.method_name, e.g.
# QRepresentation = cntk.contrib.deeprl.agent.shared.customized_models.conv_dqn
QRepresentation = dqn
ErrorClipping = False
ReplaysPerUpdate = 1

[ExperienceReplay]
Capacity = 10000
StartSize = 500
Prioritized = True
PriorityAlpha = 0.7
PriorityBeta = 1
PriorityEpsilon = 0.0001

[NetworkModel]
# Use (a list of integers) when QRepresentation is 'dqn'
HiddenLayerNodes = [100]

# Or use (a list of integers followed by two lists of integers) when
# QRepresentation is 'dueling-dqn'
# HiddenLayerNodes = [10, [5], [5]]

[Optimization]
Momentum = 0.9
InitialEta = 0.01
EtaDecayStepCount = 10000
EtaMinimum = 0.001
GradientClippingThreshold = 10

次に学習を実行します。
agent.start()で学習をスタートし、agent.step()でネットワークを更新していきます。
ゲームが終了したら、agent.end()で学習を終了させます。

n_episodes = 200
max_episode_len = 200
global_step = 0

for i in range(1, n_episodes+1):
    observation = env.reset()
    isTerminal = False
    sum_rewards = 0
    local_step = 0

    # Start a new episode.
    action, debug_info = agent.start(observation)

    while local_step < max_episode_len:
        observation, reward, isTerminal, _ = env.step(action)
        sum_rewards += reward
        local_step += 1
        global_step += 1

        if isTerminal:
            # Last observed reward/state of the episode.
            agent.end(reward, observation)
            break
        else:
            # Observe one transition and choose an action.
            action, debug_info = agent.step(reward, observation)

    if i % 10 == 0:
        print('episode:', i,
              'local_step:', local_step,
              'global_step:', global_step,
              'sum_rewards:', sum_rewards,
              'epsilon:', debug_info.get('epsilon'))

学習が終了したら、モデルを保存します。
agent.set_as_best_model()で、その時点のモデルを内部でクローンしておき、
直近にクローンしたモデルをagent.save()で保存します。

agent.set_as_best_model()
agent.save('./save/qlearning.model')

最後に、学習済みモデルをテストしてみます。
agent.evaluate()で、ネットワークのパラメータを更新させずに行動を受け取ることができます。
agent.enter_evaluation()はε-greedyのε値を0にするメソッドです。
ε値を0以外に指定したい場合は、直接インスタンス変数 agent._epsilonを更新してください。

n_episodes = 10
max_episode_len = 200
agent.enter_evaluation()
#agent._epsilon = 0.05

for i in range(1, n_episodes+1):
    observation = env.reset()
    isTerminal = False
    sum_rewards = 0
    local_step = 0

    while not isTerminal and local_step < max_episode_len:
        # Choose an action acoording to the policy.
        action = agent.evaluate(observation)
        observation, reward, isTerminal, _ = env.step(action)
        sum_rewards += reward
        local_step += 1

    print('episode:', i,
          'local_step:', local_step,
          'sum_rewards:', sum_rewards)

agent.exit_evaluation()

学習結果

episode: 10 local_step: 22 global_step: 220 sum_rewards: 22.0 epsilon: 0.98029
episode: 20 local_step: 17 global_step: 413 sum_rewards: 17.0 epsilon: 0.96292
episode: 30 local_step: 8 global_step: 582 sum_rewards: 8.0 epsilon: 0.9477099999999999
episode: 40 local_step: 22 global_step: 728 sum_rewards: 22.0 epsilon: 0.93457
episode: 50 local_step: 19 global_step: 916 sum_rewards: 19.0 epsilon: 0.91765
episode: 60 local_step: 58 global_step: 1224 sum_rewards: 58.0 epsilon: 0.88993
episode: 70 local_step: 22 global_step: 1521 sum_rewards: 22.0 epsilon: 0.8632
episode: 80 local_step: 21 global_step: 1815 sum_rewards: 21.0 epsilon: 0.83674
episode: 90 local_step: 11 global_step: 2240 sum_rewards: 11.0 epsilon: 0.79849
episode: 100 local_step: 114 global_step: 2802 sum_rewards: 114.0 epsilon: 0.74791
episode: 110 local_step: 39 global_step: 3209 sum_rewards: 39.0 epsilon: 0.71128
episode: 120 local_step: 46 global_step: 3619 sum_rewards: 46.0 epsilon: 0.67438
episode: 130 local_step: 110 global_step: 4223 sum_rewards: 110.0 epsilon: 0.62002
episode: 140 local_step: 190 global_step: 5587 sum_rewards: 190.0 epsilon: 0.49726000000000004
episode: 150 local_step: 186 global_step: 6903 sum_rewards: 186.0 epsilon: 0.37881999999999993
episode: 160 local_step: 200 global_step: 8789 sum_rewards: 200.0 epsilon: 0.20908
episode: 170 local_step: 200 global_step: 10781 sum_rewards: 200.0 epsilon: 0.1
episode: 180 local_step: 178 global_step: 12715 sum_rewards: 178.0 epsilon: 0.1
episode: 190 local_step: 200 global_step: 14662 sum_rewards: 200.0 epsilon: 0.1
episode: 200 local_step: 199 global_step: 16532 sum_rewards: 199.0 epsilon: 0.1
Finished.

テスト結果

episode: 1 local_step: 200 sum_rewards: 200.0
episode: 2 local_step: 200 sum_rewards: 200.0
episode: 3 local_step: 200 sum_rewards: 200.0
episode: 4 local_step: 200 sum_rewards: 200.0
episode: 5 local_step: 200 sum_rewards: 200.0
episode: 6 local_step: 200 sum_rewards: 200.0
episode: 7 local_step: 200 sum_rewards: 200.0
episode: 8 local_step: 200 sum_rewards: 200.0
episode: 9 local_step: 200 sum_rewards: 200.0
episode: 10 local_step: 200 sum_rewards: 200.0

以上

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0