はじめに
この記事は最近PyTorchを0.4にアップデートした筆者がハマった問題について記述したものです.
同じようなケースに出くわした人のために役立てれば幸いです.
起きた問題
とある画像をCapsule networkを使って学習していたのですが,以前はうまくいっていた学習がバージョンアップ後にうまくいかなくなりました.
具体的には0.1 - 0.2くらいの値を出していたreconstruction lossがいきなり50000とか出すようになりました….当然,学習が破綻しました.
ちなみにCapsule networkの詳細についてはこの記事では説明を割愛させていただきます.原著論文や以下の他の記事・動画等を参照ください.中でも一番下の動画は中身を理解するのにオススメです.
CapsNet (Capsule Network) の PyTorch 実装
CapsNetについての調べ
Capsule Networks Explained(英語)
Capsule Networks (CapsNets) – Tutorial - YouTube(英語)
原因
調べてみると入力画像の輝度値の範囲がToTensor()で[0.0, 1.0]になっているかと思いきや,[0, 255]のままでした.
ここまでたどり着くのに数時間四苦八苦したのは内緒です.
>> inputs
array([[[ 77.],
[ 45.],
[-103.],
...,
[ -53.],
[ 111.],
[ -37.]],
...,
[[ 25.],
[ -87.],
[ -50.],
...,
[ 110.],
[ 112.],
[ 2.]]])
>>> inputs = transforms.ToTensor()(inputs)
>>> inputs
tensor([[[ 77., 45., -103., ..., -53., 111., -37.],
[-123., 8., -106., ..., 70., 4., 40.],
[ 5., -4., 43., ..., 103., -61., 24.],
...,
[ 59., 124., 89., ..., 82., -47., -24.],
[ 3., -127., -39., ..., 115., -64., 88.],
[ 25., -87., -50., ..., 110., 112., 2.]]],
dtype=torch.float64)
ここで該当するコードをみてみると…
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backward compatibility
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
else:
return img
いつの間にかtorch.ByteTensor以外は255で割られない仕様になっている….余計なことを!
ByteTensorというのはunsigned 8-bit integerなので,入力するndarrayがuint8でないと正しく[0.0, 1.0]の範囲に調整されないようです.
私のケースの場合,floatで与えてしまっていたので255で割られず,reconstructed image(こっちは[0.0, 1.0]の範囲になっている)との誤差がとんでもない値になってしまっていたようです.
とりあえず,ToTensor()に与える前にuint8に直すようにしました….
>>> inputs = inputs.astype(np.uint8)
>>> inputs = transforms.ToTensor()(inputs)
>>> inputs
tensor([[[0.5255, 0.4039, 0.2235, ..., 0.4275, 0.1020, 0.7294],
[0.4784, 0.6745, 0.6078, ..., 0.2353, 0.0235, 0.0824],
[0.1804, 0.6314, 0.0706, ..., 0.3294, 0.2627, 0.1059],
...,
[0.3725, 0.7098, 0.5608, ..., 0.4902, 0.0000, 0.6118],
[0.7922, 0.8431, 0.9373, ..., 0.8471, 0.2627, 0.9176],
[0.2706, 0.4235, 0.1922, ..., 0.7020, 0.9608, 0.1059]]])
何故こんな修正が?(というタイトルの愚痴)
この修正に関するPull requestをみてみると,
これってバグじゃなくて単に255で割られる仕様を意識できてなかっただけでは…(-_-;
開発者の人もAwesome!とか言ってすんなりmergeしないでよ….
ただ公式のドキュメントは[0.0, 1.0]に変換すると書いたままになっている.
これはちゃんとドキュメントにもその旨を書いたほうがいいのでは…?
本当はそういう記事を書く前にそういうIssueを投げた方がいいのだろうな.
おわりに
つまるところ型変換をちゃんとしていなかった私が悪いとも言えますね….でも同じところでハマる人も世の中にはいると思いますので,そういう方にとってこの記事が少しでも役に立てれば幸いです.