Gym環境での学習が上手くいかない
解決したいこと
Gymで提供されているPong-v4環境での学習が上手くいきません。Google Colaboratoryで実装しています。
私は参考としたコードに自分が変更を加えた、畳み込み層に問題があると考えています。
発生している問題・エラー
エラーは発生していないのですが、参考書と比べて学習に時間がかかりすぎています。Rainbowを用いた場合、2時間ほどで終わる学習が24時間程度、Noisy networkの実装では学習が一切進んでいません。
参考書通りにすると畳み込み層でエラーが発生したため、直したのですが、上手く作動していないように思えます。
該当するソースコード
class RainbowDQN(nn.Module):
def __init__(self, input_shape, n_actions):
super(RainbowDQN, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(input_shape[2], 32, kernel_size=8, stride=4),
#元のコード:nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU()
)
conv_out_size = self._get_conv_out(input_shape)
self.fc_val = nn.Sequential(
NoisyLinear(conv_out_size, 256),
nn.ReLU(),
NoisyLinear(256, N_ATOMS)
)
self.fc_adv = nn.Sequential(
NoisyLinear(conv_out_size, 256),
nn.ReLU(),
NoisyLinear(256, n_actions * N_ATOMS)
)
self.register_buffer("supports", torch.arange(Vmin, Vmax+DELTA_Z, DELTA_Z))
self.softmax = nn.Softmax(dim=1)
def _get_conv_out(self, shape):
new_shape = (shape[2], shape[0], shape[1]) #追加
o = self.conv(torch.zeros(1, *new_shape))
return int(np.prod(o.size()))
def forward(self, x):
x = torch.permute(x, (0, 3, 1, 2)) #追加
batch_size = x.size()[0]
fx = x.float() / 256
conv_out = self.conv(fx).reshape(batch_size, -1)
val_out = self.fc_val(conv_out).view(batch_size, 1, N_ATOMS)
adv_out = self.fc_adv(conv_out).view(batch_size, -1, N_ATOMS)
adv_mean = adv_out.mean(dim=1, keepdim=True)
return val_out + (adv_out - adv_mean)
def both(self, x):
cat_out = self(x)
probs = self.apply_softmax(cat_out)
weights = probs * self.supports
res = weights.sum(dim=2)
return cat_out, res
def qvals(self, x):
return self.both(x)[1]
def apply_softmax(self, t):
return self.softmax(t.view(-1, N_ATOMS)).view(t.size())
自分で試したこと
当初のエラーを見たところ、観測空間が(3 * 210 * 160)の形(チャネル数×縦×横)で出力される必要がある、というようなエラーが出ました。しかしGymでの出力は(縦×横×チャネル数)、(210 * 160 * 3)であるため、引数を参照する場所を変えたり、permute関数を使うことで次元の変更を行いました。これによりエラーは無くなったのですが、学習は上手くいっていないです。
またpermute関数の挙動を自分で確認したのですが、僕の思った通りに動いてること(正しく次元の変更ができている、多分)も確認しました。
そもそも元のコードで上手くいくことが考えにくいのですが、ここ5年ほどでGymの出力する次元の変更でもあったのでしょうか?もしくはwrapperなどで出力の次元を変更する動作があるのでしょうか?
参考書
元のコード