The following is a walkthrough of the code.
_main()
Miscellaneous initialization
Set up the random seeds and other initial state.
logger.info('The first `gym.make(MineRL*)` may take several minutes. Be patient!')
os.environ['MALMO_MINECRAFT_OUTPUT_LOGDIR'] = args.outdir
# Set a random seed used in ChainerRL.
chainerrl.misc.set_random_seed(args.seed)
# Set different random seeds for train and test envs.
train_seed = args.seed # noqa: never used in this script
test_seed = 2 ** 31 - 1 - args.seed
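For example, with args.seed == 0 the training environment keeps seed 0 while the evaluation environment gets 2**31 - 1 = 2147483647, so the two never share a seed.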
Creating the environment
The environment is created with gym.make() and then wrapped by wrap_env() (described later).
core_env = gym.make(args.env)
env = wrap_env(core_env, test=False)
eval_env = wrap_env(core_env, test=True)
The Q-function
The Q-function is built by parse_arch() (described later).
# Q function
n_actions = env.action_space.n
q_func = parse_arch(args.arch, n_actions, n_input_channels=env.observation_space.shape[0])
Configuring exploration
--noisy-net-sigma decides whether a noisy network is used. If it is set, exploration comes from the noisy layers and the explorer is plain greedy; otherwise a linearly decaying epsilon-greedy explorer is used.
# explorer
if args.noisy_net_sigma is not None:
    chainerrl.links.to_factorized_noisy(q_func, sigma_scale=args.noisy_net_sigma)
    # Turn off explorer
    explorer = chainerrl.explorers.Greedy()
else:
    explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(
        1.0, args.final_epsilon, args.final_exploration_frames, env.action_space.sample)
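As a rough sketch (an illustration, not ChainerRL's implementation), the schedule configured above decays epsilon linearly from 1.0 down to args.final_epsilon over args.final_exploration_frames steps and keeps it constant afterwards:

# Illustration only: the shape of the epsilon schedule used by
# LinearDecayEpsilonGreedy(1.0, args.final_epsilon, args.final_exploration_frames, ...).
def linear_epsilon(t, start_epsilon, end_epsilon, decay_steps):
    if t >= decay_steps:
        return end_epsilon
    return start_epsilon + (end_epsilon - start_epsilon) * t / decay_steps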
Saving a drawing of the computational graph
The computational graph is rendered and saved as an image in the output directory. This is not required for training; it is only there so you can inspect the graph.
# Draw the computational graph and save it in the output directory.
sample_obs = env.observation_space.sample()
sample_batch_obs = np.expand_dims(sample_obs, 0)
chainerrl.misc.draw_computational_graph([q_func(sample_batch_obs)], os.path.join(args.outdir, 'model'))
Configuring the optimizer
The optimizer is the update rule used to fit the network parameters. Its settings here, the Adam learning rate (args.lr) and epsilon (args.adam_eps), are hyperparameters; the Nature paper's RMSprop settings are left in as a comment.
# Use the Nature paper's hyperparameters
# opt = optimizers.RMSpropGraves(lr=args.lr, alpha=0.95, momentum=0.0, eps=1e-2)
opt = chainer.optimizers.Adam(alpha=args.lr, eps=args.adam_eps) # NOTE: mirrors DQN implementation in MineRL paper
opt.setup(q_func)
Computing the number of steps
With maximum_frames fixed at 8,640,000, steps and eval_interval are computed for both the frame-skipping and the non-frame-skipping case.
# calculate corresponding `steps` and `eval_interval` according to frameskip
# = 1440 episodes if we count an episode as 6000 frames,
# = 1080 episodes if we count an episode as 8000 frames.
maximum_frames = 8640000
if args.frame_skip is None:
    steps = maximum_frames
    eval_interval = 6000 * 100  # (approx.) every 100 episode (counts "1 episode = 6000 steps")
else:
    steps = maximum_frames // args.frame_skip
    eval_interval = 6000 * 100 // args.frame_skip  # (approx.) every 100 episode (counts "1 episode = 6000 steps")
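For example, if args.frame_skip were 4 (an illustrative value), steps = 8,640,000 // 4 = 2,160,000 and eval_interval = 600,000 // 4 = 150,000, so evaluation still happens roughly every 100 episodes.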
Setting up the replay buffer
# Select a replay buffer to use
if args.prioritized:
    # Anneal beta from beta0 to 1 throughout training
    betasteps = steps / args.update_interval
    rbuf = chainerrl.replay_buffer.PrioritizedReplayBuffer(
        args.replay_capacity, alpha=0.5, beta0=0.4, betasteps=betasteps, num_steps=args.num_step_return)
else:
    rbuf = chainerrl.replay_buffer.ReplayBuffer(args.replay_capacity, args.num_step_return)
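Continuing the illustrative numbers above (steps = 2,160,000 under frame skip 4) and assuming args.update_interval == 4, betasteps = 2,160,000 / 4 = 540,000, i.e. the importance-sampling exponent beta is annealed from 0.4 up to 1 over roughly 540,000 parameter updates.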
Building the agent
phi is the $\phi$ function, i.e. the per-step observation preprocessing; here it does nothing special beyond converting the observation into a NumPy array. The agent class is selected by parse_agent() (described later).
# build agent
def phi(x):
    # observation -> NN input
    return np.asarray(x)

Agent = parse_agent(args.agent)
agent = Agent(
    q_func, opt, rbuf, gpu=args.gpu, gamma=args.gamma, explorer=explorer, replay_start_size=args.replay_start_size,
    target_update_interval=args.target_update_interval, clip_delta=args.clip_delta, update_interval=args.update_interval,
    batch_accumulator=args.batch_accumulator, phi=phi)
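To make phi concrete: after the observation wrappers, each observation is already a channel-first float image (or a lazily stacked version of it when frame stacking is enabled), and np.asarray simply materializes it as the network input. A minimal illustration, where the 64x64 POV size, RGB channels and a 4-frame stack are assumed values:

# Illustration only: the kind of array phi receives and returns under one
# assumed configuration (RGB 64x64 POV, 4-frame stack -> 3 * 4 = 12 channels).
import numpy as np

fake_obs = np.zeros((12, 64, 64), dtype=np.float32)  # stand-in for a stacked, scaled POV
print(np.asarray(fake_obs).shape)  # (12, 64, 64) -- same as phi(fake_obs), just guaranteed to be an ndarray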
Loading a saved agent
if args.load:
    agent.load(args.load)
Training or demo (evaluation only)
# experiment
if args.demo:
    eval_stats = chainerrl.experiments.eval_performance(env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_runs)
    logger.info('n_runs: {} mean: {} median: {} stdev {}'.format(
        args.eval_n_runs, eval_stats['mean'], eval_stats['median'], eval_stats['stdev']))
else:
    chainerrl.experiments.train_agent_with_evaluation(
        agent=agent, env=env, steps=steps,
        eval_n_steps=None, eval_n_episodes=args.eval_n_runs, eval_interval=eval_interval,
        outdir=args.outdir, eval_env=eval_env, save_best_so_far_agent=True,
    )
env.close()
eval_env.close()
wrap_env()
def wrap_env(env, test):
    # wrap env: time limit...
    if isinstance(env, gym.wrappers.TimeLimit):
        logger.info('Detected `gym.wrappers.TimeLimit`! Unwrap it and re-wrap our own time limit.')
        env = env.env
        max_episode_steps = env.spec.max_episode_steps
        env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)
    # wrap env: observation...
    # NOTE: wrapping order matters!
    if test and args.monitor:
        env = ContinuingTimeLimitMonitor(
            env, os.path.join(args.outdir, env.spec.id, 'monitor'),
            mode='evaluation' if test else 'training', video_callable=lambda episode_id: True)
    if args.frame_skip is not None:
        env = FrameSkip(env, skip=args.frame_skip)
    if args.gray_scale:
        env = GrayScaleWrapper(env, dict_space_key='pov')
    if args.env.startswith('MineRLNavigate'):
        env = PoVWithCompassAngleWrapper(env)
    else:
        env = ObtainPoVWrapper(env)
    env = MoveAxisWrapper(env, source=-1, destination=0)  # convert hwc -> chw as Chainer requires.
    env = ScaledFloatFrame(env)
    if args.frame_stack is not None and args.frame_stack > 0:
        env = FrameStack(env, args.frame_stack, channel_order='chw')
    # wrap env: action...
    if not args.disable_action_prior:
        env = SerialDiscreteActionWrapper(
            env,
            always_keys=args.always_keys, reverse_keys=args.reverse_keys, exclude_keys=args.exclude_keys, exclude_noop=args.exclude_noop)
    else:
        env = CombineActionWrapper(env)
        env = SerialDiscreteCombineActionWrapper(env)
    if test and args.noisy_net_sigma is None:
        env = RandomizeAction(env, args.eval_epsilon)
    env_seed = test_seed if test else train_seed
    # env.seed(int(env_seed))  # TODO: not supported yet
    return env
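If you want to see what this wrapper stack actually produces, a quick check of the resulting spaces can help. A sketch (it assumes args has already been parsed, uses MineRLTreechop-v0 purely as an example name, and note that creating a MineRL environment launches Minecraft and can take several minutes):

# Sketch (illustration only): inspect the spaces produced by wrap_env.
core_env = gym.make('MineRLTreechop-v0')  # example environment name
env = wrap_env(core_env, test=False)
print(env.observation_space)  # Box in CHW layout, values scaled to [0, 1]
print(env.action_space)       # Discrete(n) after the action wrappers
env.close()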
parse_agent()
def parse_agent(agent):
    return {'DQN': chainerrl.agents.DQN,
            'DoubleDQN': chainerrl.agents.DoubleDQN,
            'PAL': chainerrl.agents.PAL,
            'CategoricalDoubleDQN': chainerrl.agents.CategoricalDoubleDQN}[agent]
parse_arch()
def parse_arch(arch, n_actions, n_input_channels):
    if arch == 'dueling':
        # Conv2Ds of (channel, kernel, stride): [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
        return DuelingDQN(n_actions, n_input_channels=n_input_channels, hiddens=[256])
    elif arch == 'distributed_dueling':
        n_atoms = 51
        v_min = -10
        v_max = 10
        return DistributionalDuelingDQN(n_actions, n_atoms, v_min, v_max, n_input_channels=n_input_channels)
    else:
        raise RuntimeError('Unsupported architecture name: {}'.format(arch))
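For the 'distributed_dueling' head, n_atoms, v_min and v_max define the fixed support on which the return distribution is represented (C51-style). A small illustration, not part of the script:

# Illustration: the 51-atom support implied by n_atoms=51, v_min=-10, v_max=10.
import numpy as np

support = np.linspace(-10, 10, 51)            # atom locations: -10.0, -9.6, ..., 10.0
print(len(support), support[0], support[-1])  # 51 -10.0 10.0
# spacing between atoms: (v_max - v_min) / (n_atoms - 1) = 20 / 50 = 0.4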