問題
pyroで変分ベイズを触っているときに以下のようなエラーが出た。
(" "と[ ]の中身はそれぞれ違うと思います。)
error
ValueError: at site "y", invalid log_prob shape
Expected [], actual [200]
Try one of the following fixes:
- enclose the batched tensor in a with plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions
解決
あまりよくわからないけど、モデル設計の以下の部分と近似分布の形状が合っていない?っぽい。
model
pm_Y = pyro.sample('pm_Y', dist.Normal(mu, sigma), obs=Y)
これを以下のようにしたら解決した。
model
pm_Y = pyro.sample('pm_Y', dist.Normal(mu, sigma).to_event(1), obs=Y)
参考文献
以下はpyroのテンソルについて詳しく書いていたもの。
https://torch.classcat.com/2020/07/15/pyro-1-3-tensor-shapes/