More than 5 years have passed since last update.

Explaining the MineRL baselines source code

Posted at 2019-10-06

This article explains the source code.

_main()

Miscellaneous initialization

Initializes the random seed and other settings.

    logger.info('The first `gym.make(MineRL*)` may take several minutes. Be patient!')

    os.environ['MALMO_MINECRAFT_OUTPUT_LOGDIR'] = args.outdir

    # Set a random seed used in ChainerRL.
    chainerrl.misc.set_random_seed(args.seed)

    # Set different random seeds for train and test envs.
    train_seed = args.seed  # noqa: never used in this script
    test_seed = 2 ** 31 - 1 - args.seed
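
To make the seed arithmetic concrete, here is a small sketch. The helper name `derive_seeds` is hypothetical (it is not part of the baselines); it only repeats the two assignments from the snippet above.

```python
def derive_seeds(seed):
    """Return (train_seed, test_seed) the same way the snippet above does."""
    train_seed = seed
    # 2 ** 31 - 1 is the largest signed 32-bit integer, so the test seed
    # is the training seed mirrored within that range.
    test_seed = 2 ** 31 - 1 - seed
    return train_seed, test_seed


print(derive_seeds(0))    # (0, 2147483647)
print(derive_seeds(123))  # (123, 2147483524)
```

Because 2 ** 31 - 1 is odd, the two seeds can never be equal for an integer `seed`.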

Creating the environment

The environment is created with gym.make().

    core_env = gym.make(args.env)
    env = wrap_env(core_env, test=False)
    eval_env = wrap_env(core_env, test=True)

Q-function

Created with parse_arch() (described later).

    # Q function
    n_actions = env.action_space.n
    q_func = parse_arch(args.arch, n_actions, n_input_channels=env.observation_space.shape[0])

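Configuring exploration

--noisy-net-sigma decides whether a noisy net is used. If it is given, the Q-function is converted to use factorized noisy layers and the explorer is set to Greedy, because the noise itself drives exploration. Otherwise, epsilon-greedy is used, with epsilon decaying linearly from 1.0 to args.final_epsilon over args.final_exploration_frames steps.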

    # explorer
    if args.noisy_net_sigma is not None:
        chainerrl.links.to_factorized_noisy(q_func, sigma_scale=args.noisy_net_sigma)
        # Turn off explorer
        explorer = chainerrl.explorers.Greedy()
    else:
        explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(
            1.0, args.final_epsilon, args.final_exploration_frames, env.action_space.sample)
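
The linear decay used in the `else` branch can be sketched in plain Python. This is a conceptual illustration, not ChainerRL's internal code; the function name `linear_decay_epsilon` and the numbers below are assumptions chosen for the example.

```python
def linear_decay_epsilon(t, start_epsilon, end_epsilon, decay_steps):
    """Epsilon at step t when decaying linearly from start to end."""
    if t >= decay_steps:
        # After the decay period, epsilon stays at its final value.
        return end_epsilon
    return start_epsilon + (end_epsilon - start_epsilon) * t / decay_steps


# Hypothetical values for illustration only.
for t in (0, 500_000, 1_000_000, 2_000_000):
    print(t, linear_decay_epsilon(t, 1.0, 0.01, 1_000_000))
```

With probability epsilon the agent takes a random action (here, `env.action_space.sample`), so early training is almost entirely random and later training is mostly greedy.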

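Drawing the computational graph and saving it as an image

The computational graph is drawn and saved as an image. (This is not strictly necessary; it is only used to check the computational graph.)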

    # Draw the computational graph and save it in the output directory.
    sample_obs = env.observation_space.sample()
    sample_batch_obs = np.expand_dims(sample_obs, 0)
    chainerrl.misc.draw_computational_graph([q_func(sample_batch_obs)], os.path.join(args.outdir, 'model'))

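Configuring the optimizer

The optimizer is the function-optimization method used to update the Q-function's parameters. The arguments passed to it are hyperparameters.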


    # Use the Nature paper's hyperparameters
    # opt = optimizers.RMSpropGraves(lr=args.lr, alpha=0.95, momentum=0.0, eps=1e-2)
    opt = chainer.optimizers.Adam(alpha=args.lr, eps=args.adam_eps)  # NOTE: mirrors DQN implementation in MineRL paper

    opt.setup(q_func)

Calculating the number of steps

With maximum_frames set to 8,640,000, steps and eval_interval are computed separately for the cases with and without frame skip.

    # calculate corresponding `steps` and `eval_interval` according to frameskip
    # = 1440 episodes if we count an episode as 6000 frames,
    # = 1080 episodes if we count an episode as 8000 frames.
    maximum_frames = 8640000
    if args.frame_skip is None:
        steps = maximum_frames
        eval_interval = 6000 * 100  # (approx.) every 100 episode (counts "1 episode = 6000 steps")
    else:
        steps = maximum_frames // args.frame_skip
        eval_interval = 6000 * 100 // args.frame_skip  # (approx.) every 100 episode (counts "1 episode = 6000 steps")
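
As a worked example of this calculation, the sketch below reproduces the same arithmetic in a standalone function. The name `schedule` and the frame skip of 4 are assumptions for illustration, not values taken from the baselines.

```python
def schedule(frame_skip, maximum_frames=8640000, frames_per_episode=6000):
    """Mirror the steps / eval_interval calculation from the snippet above."""
    eval_interval = frames_per_episode * 100  # roughly every 100 episodes
    if frame_skip is None:
        return maximum_frames, eval_interval
    # With frame skip, one agent step consumes `frame_skip` frames.
    return maximum_frames // frame_skip, eval_interval // frame_skip


print(schedule(None))  # (8640000, 600000)
print(schedule(4))     # (2160000, 150000)
print(8640000 // 6000, 8640000 // 8000)  # 1440 1080
```

The last line confirms the episode counts quoted in the source comments: 1440 episodes at 6000 frames each, or 1080 at 8000 frames each.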

Configuring the replay buffer

    # Select a replay buffer to use
    if args.prioritized:
        # Anneal beta from beta0 to 1 throughout training
        betasteps = steps / args.update_interval
        rbuf = chainerrl.replay_buffer.PrioritizedReplayBuffer(
            args.replay_capacity, alpha=0.5, beta0=0.4, betasteps=betasteps, num_steps=args.num_step_return)
    else:
        rbuf = chainerrl.replay_buffer.ReplayBuffer(args.replay_capacity, args.num_step_return)
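
The source comment says beta is annealed from beta0 to 1 over the course of training. A minimal sketch of such a schedule, assuming a simple linear interpolation per update, looks like this. It is not ChainerRL's internal code, and `annealed_beta`, the step count, and the update interval are hypothetical.

```python
def annealed_beta(n_updates, beta0, betasteps):
    """Linearly interpolate beta from beta0 to 1.0 over `betasteps` updates."""
    fraction = min(1.0, n_updates / betasteps)
    return beta0 + (1.0 - beta0) * fraction


steps = 2160000      # hypothetical: 8640000 frames with a frame skip of 4
update_interval = 4  # hypothetical value of args.update_interval
betasteps = steps / update_interval  # number of updates during training

for n in (0, betasteps / 2, betasteps):
    print(n, annealed_beta(n, beta0=0.4, betasteps=betasteps))
```

Dividing `steps` by `update_interval` gives the number of updates rather than the number of environment steps, which is presumably why the source computes `betasteps` that way.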

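Creating the agent

phi is the $\phi$ function. It has no deep meaning; it is simply preprocessing applied to the observation at every step.
The agent class is obtained with parse_agent() (described later).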

    # build agent
    def phi(x):
        # observation -> NN input
        return np.asarray(x)
    Agent = parse_agent(args.agent)
    agent = Agent(
        q_func, opt, rbuf, gpu=args.gpu, gamma=args.gamma, explorer=explorer, replay_start_size=args.replay_start_size,
        target_update_interval=args.target_update_interval, clip_delta=args.clip_delta, update_interval=args.update_interval,
        batch_accumulator=args.batch_accumulator, phi=phi)

load

    if args.load:
        agent.load(args.load)

Training or running

    # experiment
    if args.demo:
        eval_stats = chainerrl.experiments.eval_performance(env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_runs)
        logger.info('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'], eval_stats['stdev']))
    else:
        chainerrl.experiments.train_agent_with_evaluation(
            agent=agent, env=env, steps=steps,
            eval_n_steps=None, eval_n_episodes=args.eval_n_runs, eval_interval=eval_interval,
            outdir=args.outdir, eval_env=eval_env, save_best_so_far_agent=True,
        )

    env.close()
    eval_env.close()

wrap_env()

    def wrap_env(env, test):
        # wrap env: time limit...
        if isinstance(env, gym.wrappers.TimeLimit):
            logger.info('Detected `gym.wrappers.TimeLimit`! Unwrap it and re-wrap our own time limit.')
            env = env.env
            max_episode_steps = env.spec.max_episode_steps
            env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)

        # wrap env: observation...
        # NOTE: wrapping order matters!

        if test and args.monitor:
            env = ContinuingTimeLimitMonitor(
                env, os.path.join(args.outdir, env.spec.id, 'monitor'),
                mode='evaluation' if test else 'training', video_callable=lambda episode_id: True)
        if args.frame_skip is not None:
            env = FrameSkip(env, skip=args.frame_skip)
        if args.gray_scale:
            env = GrayScaleWrapper(env, dict_space_key='pov')
        if args.env.startswith('MineRLNavigate'):
            env = PoVWithCompassAngleWrapper(env)
        else:
            env = ObtainPoVWrapper(env)
        env = MoveAxisWrapper(env, source=-1, destination=0)  # convert hwc -> chw as Chainer requires.
        env = ScaledFloatFrame(env)
        if args.frame_stack is not None and args.frame_stack > 0:
            env = FrameStack(env, args.frame_stack, channel_order='chw')

        # wrap env: action...
        if not args.disable_action_prior:
            env = SerialDiscreteActionWrapper(
                env,
                always_keys=args.always_keys, reverse_keys=args.reverse_keys, exclude_keys=args.exclude_keys, exclude_noop=args.exclude_noop)
        else:
            env = CombineActionWrapper(env)
            env = SerialDiscreteCombineActionWrapper(env)

        if test and args.noisy_net_sigma is None:
            env = RandomizeAction(env, args.eval_epsilon)

        env_seed = test_seed if test else train_seed
        # env.seed(int(env_seed))  # TODO: not supported yet
        return env
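
To illustrate what a frame-skip wrapper such as the one applied above generally does, here is a minimal sketch with no gym dependency. `CountingEnv` and `SimpleFrameSkip` are hypothetical stand-ins written for this example; the baselines' actual `FrameSkip` may differ in its details.

```python
class CountingEnv:
    """Hypothetical toy environment: reward 1 per frame, ends after 10 frames."""

    def __init__(self):
        self.frames = 0

    def reset(self):
        self.frames = 0
        return self.frames

    def step(self, action):
        self.frames += 1
        done = self.frames >= 10
        return self.frames, 1.0, done, {}


class SimpleFrameSkip:
    """Repeat each action `skip` times and sum the rewards (illustrative only)."""

    def __init__(self, env, skip):
        self.env = env
        self.skip = skip

    def reset(self):
        return self.env.reset()

    def step(self, action):
        total_reward = 0.0
        for _ in range(self.skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break  # stop repeating once the episode has ended
        return obs, total_reward, done, info


env = SimpleFrameSkip(CountingEnv(), skip=4)
env.reset()
print(env.step(0))  # (4, 4.0, False, {})
print(env.step(0))  # (8, 4.0, False, {})
print(env.step(0))  # (10, 2.0, True, {})
```

Each wrapper in `wrap_env()` follows the same pattern of holding an inner environment and transforming what goes in or comes out, which is why the order of wrapping matters.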

parse_agent()

    def parse_agent(agent):
        return {'DQN': chainerrl.agents.DQN,
                'DoubleDQN': chainerrl.agents.DoubleDQN,
                'PAL': chainerrl.agents.PAL,
                'CategoricalDoubleDQN': chainerrl.agents.CategoricalDoubleDQN}[agent]
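
parse_agent() is a plain dictionary lookup from a name to a class. The sketch below shows the same pattern with placeholder classes so that it runs without ChainerRL; `parse_agent_sketch` and the empty classes are hypothetical.

```python
class DQN:
    """Placeholder standing in for chainerrl.agents.DQN."""


class DoubleDQN:
    """Placeholder standing in for chainerrl.agents.DoubleDQN."""


def parse_agent_sketch(name):
    # Indexing the dict returns the class itself, not an instance.
    return {'DQN': DQN, 'DoubleDQN': DoubleDQN}[name]


Agent = parse_agent_sketch('DoubleDQN')
print(Agent.__name__)  # DoubleDQN

try:
    parse_agent_sketch('A3C')
except KeyError as err:
    # A name that is not in the dict fails with KeyError.
    print('unknown agent:', err)
```

Note that with this pattern an unsupported name raises a bare KeyError, unlike parse_arch(), which raises a RuntimeError with an explicit message.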

parse_arch()

    def parse_arch(arch, n_actions, n_input_channels):
        if arch == 'dueling':
            # Conv2Ds of (channel, kernel, stride): [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
            return DuelingDQN(n_actions, n_input_channels=n_input_channels, hiddens=[256])
        elif arch == 'distributed_dueling':
            n_atoms = 51
            v_min = -10
            v_max = 10
            return DistributionalDuelingDQN(n_actions, n_atoms, v_min, v_max, n_input_channels=n_input_channels)
        else:
            raise RuntimeError('Unsupported architecture name: {}'.format(arch))
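
In distributional (categorical) DQN, `n_atoms`, `v_min`, and `v_max` describe a fixed grid of return values over which the network predicts a probability distribution. A minimal sketch of that grid, assuming evenly spaced atoms, is shown below. `support_atoms` is a hypothetical helper, not a ChainerRL function.

```python
def support_atoms(n_atoms, v_min, v_max):
    """Evenly spaced return values ("atoms") between v_min and v_max."""
    delta = (v_max - v_min) / (n_atoms - 1)  # spacing between neighbours
    return [v_min + i * delta for i in range(n_atoms)]


atoms = support_atoms(51, -10, 10)
print(len(atoms), atoms[0], atoms[25], atoms[-1])
```

With 51 atoms between -10 and 10, the spacing is 0.4 and the middle atom sits at 0.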