#Pytorchでsemantic segmentationして困ったこと
卒論書くときにpytorchでsemantic segmentationをしたのだが, いろいろ困ったことので備忘録的にまとめておく.
IoU
torchvisionの方にbox_iouというものがあるのだが, segmentation用の奴が見つからなかったので困った
結局この実装を使った
torch.nn.CrossEntropyLoss
こいつがマジで厄介
documentを見ると$Input:[Minibatch, C, d_1...]$とか書いてたので推論した結果と正解データのマスクのshapeを[バッチ数, クラス数, H, W]みたいな感じのone-hot表現にしてたら無限にエラーを吐いてきて1,2日詰まった.
実際には推論結果のshapeは上記でいいが, 正解データの方は$target:[Minibatch, d_1...]$という感じらしく, shapeが[バッチ数, H, W]で各ピクセルにラベルがついている状態が正解らしい.
BinaryCrossEntropyがinput, target共に同じshapeだったので合わせていたら違って腹が立った.ちゃんと確認しろ
あとtargetの方は各ピクセルのラベルがinputのクラスのindexとあっていなければいけないらしい. つまりクラス数が20個なら, targetのラベルは0~19の値でなくてはいけない.