結論から言うと、画像要素はそれぞれこのような順番の違いがある
- OpenCV: [高さ、幅、色チャネル]
- PIL: [高さ、幅、色チャネル]
- PyTorch: [色チャネル、高さ、幅]
故に、PyTorchでは必ず
img_transformed = img_transformed.numpy().transpose((1, 2, 0))
のように要素を入れ替える必要があるので注意
通常のnumpy3次元配列とOpenCVやPILの3次元配列のイメージの違い
通常のnumpy3次元配列はこのようなイメージだと思う
通常のnumpy3次元配列
X = np.array([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]],
[[13, 14, 15]],
[16, 17, 18]]])
要素の取り出し
# 最初の数字が三次元目の軸(奥行き方向)、2番目が二次元目の軸(縦)、3番目が一次元目(横)
>>> X[0, 1, 2]
6
# 奥行き軸(axis=0)の0番目を取り出す
>>> X[0, :, :]
array([[1, 2, 3],
[4, 5, 6]])
# 縦軸(axis=1)の1番目を取り出す
>>> X[:, 1, :]
array([[ 4, 5, 6],
[10, 11, 12]])
# 横軸(axis=2)の2番目を取り出す
>>> X[:, :, 2]
array([[ 3, 6],
[ 9, 12]])
[奥行き=RGBチャネル, 高さ, 幅] の順なのでOpenCVやPILで読み込んだ画像をPyTorchに変換するとこのイメージと同じになる
しかしその前に一般的に画像処理に使われるライブラリOpenCV, PILでは同じ3次元ではあるのだが、配列構成のイメージが違う
OpenCV, PILの3次元配列
OpenCVやPILなどは、 [高さ, 幅、奥行き=RGBチャネル] の順になる(OpenCVはBGRの順なので注意)
import cv2
img = cv2.imread('img.png')
print(img.shape)
> (172, 180, 3)
print(img)
>
[[[ 58 35 33]
[ 57 33 30]
[ 54 29 25]
...
[ 7 6 8]
[ 8 8 9]
[ 10 10 11]]
[[ 60 36 34]
[ 60 34 31]
[ 56 30 26]
...
[ 6 6 8]
[ 8 7 9]
[ 10 9 10]]
[[ 61 37 34]
[ 62 36 32]
[ 60 33 29]
...
[ 6 5 7]
[ 7 7 8]
[ 9 9 10]]
要素の取り出し
# 行(高さ)を取得
print(img[0, :, :].shape)
print(img[0, :, :])
>
(180, 3)
[[58 35 33]
[57 33 30]
[54 29 25]
...
[ 8 8 9]
[10 10 11]]
# 列(幅)を取得
print(img[:, 0, :].shape)
print(img[:, 0, :])
>
(172, 3)
[[ 58 35 33]
[ 60 36 34]
[ 61 37 34]
[ 62 36 34]
...
[139 89 69]
[121 72 53]]
# 奥行き(チャネル)を取得
print(img[:, :, 0].shape)
print(img[:, :, 0])
>
(172, 180)
[[ 58 57 54 ... 7 8 10]
[ 60 60 56 ... 6 8 10]
[ 61 62 60 ... 6 7 9]
...
[157 154 144 ... 20 23 24]
[139 134 124 ... 22 24 25]
[121 116 107 ... 22 23 25]]
[高さ, 幅, 奥行き=RGBチャンル] の順なので次元ごとの要素の取り出し方が通常のnumpy及びPyTorchとは異なることに注意
高さと幅のイメージが直感的に異なるが実態はこうなっている。
なぜこうなっているかだが、テレビのディスプレイなどでは、1画素の中でRGBの光の混色で色を表す仕組みであるため配列の横軸方向に色チャネルを並べ、次点で高さ、幅という順番の3次元配列になったのだと思われる。
補足:
※ 操作時には、最も外側の次元から順に取り出していくと覚えると良い。
参考