Q学習のプログラム作成中に、エラーが発生しました。
def get_target(reword, gamma, next_q_max, done) do
reword + (1-done) * gamma * next_q_max
end
この関数に与えた値が、Nx.Tensorだったので、+ や *で演算ができなかったのです。
こんな時には、defnを使って関数を定義します。
算術記号を使った記述をNxの演算関数に置き換えてくれるので、うまくいきます。
defn get_target(reword, gamma, next_q_max, done) do
reword + (1-done) * gamma * next_q_max
end
defで記述するとしたら、次のような記述ですね。
|>演算子のおかげで、これでもまあ、読めなくはない。
def get_target_nx(reword, gamma, next_q_max, done) do
Nx.subtract(1, done)
|> Nx.multiply(gamma)
|> Nx.multiply(next_q_max)
|> Nx.add(reword)
end
算術演算子で書けた方がわかりやすいので、defn使いましょう!