LoginSignup
ryo0606
@ryo0606

Are you sure you want to delete the question?

Leaving a resolved question undeleted may help others!

Gym環境での学習が上手くいかない

解決したいこと

Gymで提供されているPong-v4環境での学習が上手くいきません。Google Colaboratoryで実装しています。
私は参考としたコードに自分が変更を加えた、畳み込み層に問題があると考えています。

発生している問題・エラー

エラーは発生していないのですが、参考書と比べて学習に時間がかかりすぎています。Rainbowを用いた場合、2時間ほどで終わる学習が24時間程度、Noisy networkの実装では学習が一切進んでいません。
参考書通りにすると畳み込み層でエラーが発生したため、直したのですが、上手く作動していないように思えます。

該当するソースコード

class RainbowDQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(RainbowDQN, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[2], 32, kernel_size=8, stride=4),
            #元のコード:nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)
        self.fc_val = nn.Sequential(
            NoisyLinear(conv_out_size, 256),
            nn.ReLU(),
            NoisyLinear(256, N_ATOMS)
        )

        self.fc_adv = nn.Sequential(
            NoisyLinear(conv_out_size, 256),
            nn.ReLU(),
            NoisyLinear(256, n_actions * N_ATOMS)
        )

        self.register_buffer("supports", torch.arange(Vmin, Vmax+DELTA_Z, DELTA_Z))
        self.softmax = nn.Softmax(dim=1)

    def _get_conv_out(self, shape):
        new_shape = (shape[2], shape[0], shape[1]) #追加
        o = self.conv(torch.zeros(1, *new_shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        x = torch.permute(x, (0, 3, 1, 2)) #追加
        batch_size = x.size()[0]
        fx = x.float() / 256
        conv_out = self.conv(fx).reshape(batch_size, -1)
        val_out = self.fc_val(conv_out).view(batch_size, 1, N_ATOMS)
        adv_out = self.fc_adv(conv_out).view(batch_size, -1, N_ATOMS)
        adv_mean = adv_out.mean(dim=1, keepdim=True)
        return val_out + (adv_out - adv_mean)

    def both(self, x):
        cat_out = self(x)
        probs = self.apply_softmax(cat_out)
        weights = probs * self.supports
        res = weights.sum(dim=2)
        return cat_out, res

    def qvals(self, x):
        return self.both(x)[1]

    def apply_softmax(self, t):
        return self.softmax(t.view(-1, N_ATOMS)).view(t.size())

自分で試したこと

当初のエラーを見たところ、観測空間が(3 * 210 * 160)の形(チャネル数×縦×横)で出力される必要がある、というようなエラーが出ました。しかしGymでの出力は(縦×横×チャネル数)、(210 * 160 * 3)であるため、引数を参照する場所を変えたり、permute関数を使うことで次元の変更を行いました。これによりエラーは無くなったのですが、学習は上手くいっていないです。
またpermute関数の挙動を自分で確認したのですが、僕の思った通りに動いてること(正しく次元の変更ができている、多分)も確認しました。
そもそも元のコードで上手くいくことが考えにくいのですが、ここ5年ほどでGymの出力する次元の変更でもあったのでしょうか?もしくはwrapperなどで出力の次元を変更する動作があるのでしょうか?

参考書

元のコード

0

1Answer

そもそも元のコードで上手くいくことが考えにくいのですが、ここ5年ほどでGymの出力する次元の変更でもあったのでしょうか?もしくはwrapperなどで出力の次元を変更する動作があるのでしょうか?

元のコードを見ました。
コード内の82行目で、wrapperが適用されています。

env = gym.make(params.env_name)
env = ptan.common.wrappers.wrap_dqn(env)
env.seed(common.SEED)

また、該当のwrapperのコードを見たところ、出力の次元を変更する処理が追加されているようです(もちろんその他の処理も)。

上記を踏まえて、今回発生している問題の原因究明にあたって、まずは次の2点を調べると良いかと思われます。

  • GPU(TPU?)で動作しているか
  • 参考書通りの処理を実装できているか
1

Comments

  1. @ryo0606

    Questioner

    ご返信ありがとうございます。また返信が遅れて申し訳ありません。丁度そこは入れるとよく分からないエラーが出るところであり、cartpoleなどの観測空間が画像ではない環境だと必要なかったため、コードから省いている箇所でした。もう一度出ていたエラーをよく確認し、wrapper部分を上手く取り込んだところ、上手く動くようになりました。ありがとうございました。

Your answer might help someone💌