class RainbowDQN(nn.Module):
    """Dueling distributional (C51-style) DQN with NoisyLinear heads.

    Takes NHWC image observations (H, W, C last — see ``_get_conv_out`` and the
    permute in ``forward``) and outputs, per action, logits over N_ATOMS value
    atoms spanning [Vmin, Vmax].
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        # input_shape is (H, W, C): channels are the LAST axis, hence input_shape[2].
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[2], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        conv_out_size = self._get_conv_out(input_shape)
        # Dueling heads: one state-value distribution, one advantage
        # distribution per action. NoisyLinear supplies the exploration noise.
        self.fc_val = nn.Sequential(
            NoisyLinear(conv_out_size, 256),
            nn.ReLU(),
            NoisyLinear(256, N_ATOMS),
        )
        self.fc_adv = nn.Sequential(
            NoisyLinear(conv_out_size, 256),
            nn.ReLU(),
            NoisyLinear(256, n_actions * N_ATOMS),
        )
        # Atom support of the categorical value distribution.
        # NOTE(review): a float arange with fractional step can yield
        # N_ATOMS + 1 atoms due to rounding; torch.linspace(Vmin, Vmax, N_ATOMS)
        # is the robust equivalent — confirm DELTA_Z == (Vmax - Vmin) / (N_ATOMS - 1)
        # before switching.
        self.register_buffer("supports", torch.arange(Vmin, Vmax + DELTA_Z, DELTA_Z))
        self.softmax = nn.Softmax(dim=1)

    def _get_conv_out(self, shape):
        """Return the flattened conv-stack output size for an (H, W, C) input shape."""
        # The conv stack expects NCHW, so reorder to (C, H, W) for the dry run.
        new_shape = (shape[2], shape[0], shape[1])
        o = self.conv(torch.zeros(1, *new_shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Return per-action distribution logits, shape (batch, n_actions, N_ATOMS)."""
        x = torch.permute(x, (0, 3, 1, 2))  # NHWC -> NCHW for Conv2d
        batch_size = x.size()[0]
        fx = x.float() / 256  # scale raw (presumably uint8) pixels to ~[0, 1)
        conv_out = self.conv(fx).reshape(batch_size, -1)
        val_out = self.fc_val(conv_out).view(batch_size, 1, N_ATOMS)
        adv_out = self.fc_adv(conv_out).view(batch_size, -1, N_ATOMS)
        adv_mean = adv_out.mean(dim=1, keepdim=True)
        # Dueling combination: Q-distribution logits = V + (A - mean(A)).
        return val_out + (adv_out - adv_mean)

    def both(self, x):
        """Return (distribution logits, expected Q-values) for observations x."""
        cat_out = self(x)
        probs = self.apply_softmax(cat_out)
        weights = probs * self.supports  # expectation over the atom support
        res = weights.sum(dim=2)
        return cat_out, res

    def qvals(self, x):
        """Return only the expected Q-values, shape (batch, n_actions)."""
        return self.both(x)[1]

    def apply_softmax(self, t):
        """Softmax over the atom axis, applied independently per (batch, action) row."""
        return self.softmax(t.view(-1, N_ATOMS)).view(t.size())