PytrochでVOCのマスク画像をonehotにしなければならなかったのですが、少し調べるのに時間がかかったのでここに記します。
for文を回せばできることはできますが、
tensorのままスマートにやりたかったので探しました。
onehot化する
def onehot(image_tensor, n_clsses):
h, w = image_tensor.size()
onehot = torch.LongTensor(n_classes, h, w).zero_()
image_tensor = image_tensor.unsqueeze_(0)
onehot = onehot.scatter_(0, image_tensor, 1)
return onehot
簡単にコードの解説をします。
最初にtorch.LongTensor(n_classes, h, w).zero_()
で
縦横のサイズはおなじで奥行きがクラス数の、0で埋められたTensorを作ります。
次にimage_tensor = image_tensor.unsqueeze_(0)
で先程作ったゼロTensorにサイズをあわせます。
ゼロTensorのサイズは(n_classes, h, w)なので、
入力する画像をサイズ(h, w)の0からクラス数までの整数をもったTensorとすると、
unsqueeze_(0)
で(1, h, w)としています。
最後にonehot.scatter_(0, image_tensor, 1)
でondhotにすることができます。
引数はscatter_(dim, index, src)となっており、
image_tensorをインデックスとして0を1に変換するということになります。
使ってみる
実際に使ってみましょう
import torch
import torch.nn as nn
import torch.nn.functional as F
n_classes = 4
w, h = 8, 8
image_tensor = torch.randint(0, n_classes, (w, h))
print(image_tensor)
# tensor([[0, 1, 3, 3, 3, 0, 3, 0],
# [1, 3, 2, 1, 2, 3, 1, 3],
# [1, 0, 2, 0, 1, 3, 3, 1],
# [1, 0, 0, 1, 3, 1, 2, 0],
# [0, 0, 0, 3, 3, 3, 1, 3],
# [2, 3, 2, 3, 0, 1, 2, 0],
# [1, 1, 2, 0, 1, 3, 0, 1],
# [2, 0, 3, 0, 1, 1, 1, 0]])
def onehot(image_tensor, n_clsses):
h, w = image_tensor.size()
onehot = torch.LongTensor(n_classes, h, w).zero_()
image_tensor = image_tensor.unsqueeze_(0)
onehot = onehot.scatter_(0, image_tensor, 1)
return onehot
onehot_tensor = onehot(image_tensor, n_classes)
print(onehot_tensor)
# tensor([[[1, 0, 0, 0, 0, 1, 0, 1],
# [0, 0, 0, 0, 0, 0, 0, 0],
# [0, 1, 0, 1, 0, 0, 0, 0],
# [0, 1, 1, 0, 0, 0, 0, 1],
# [1, 1, 1, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 1, 0, 0, 1],
# [0, 0, 0, 1, 0, 0, 1, 0],
# [0, 1, 0, 1, 0, 0, 0, 1]],
# [[0, 1, 0, 0, 0, 0, 0, 0],
# [1, 0, 0, 1, 0, 0, 1, 0],
# [1, 0, 0, 0, 1, 0, 0, 1],
# [1, 0, 0, 1, 0, 1, 0, 0],
# [0, 0, 0, 0, 0, 0, 1, 0],
# [0, 0, 0, 0, 0, 1, 0, 0],
# [1, 1, 0, 0, 1, 0, 0, 1],
# [0, 0, 0, 0, 1, 1, 1, 0]],
# [[0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 1, 0, 1, 0, 0, 0],
# [0, 0, 1, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 1, 0],
# [0, 0, 0, 0, 0, 0, 0, 0],
# [1, 0, 1, 0, 0, 0, 1, 0],
# [0, 0, 1, 0, 0, 0, 0, 0],
# [1, 0, 0, 0, 0, 0, 0, 0]],
# [[0, 0, 1, 1, 1, 0, 1, 0],
# [0, 1, 0, 0, 0, 1, 0, 1],
# [0, 0, 0, 0, 0, 1, 1, 0],
# [0, 0, 0, 0, 1, 0, 0, 0],
# [0, 0, 0, 1, 1, 1, 0, 1],
# [0, 1, 0, 1, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 1, 0, 0],
# [0, 0, 1, 0, 0, 0, 0, 0]]])
ちゃんとサイズを保ったままonehotになっているのがわかると思います。
追記
より少ない行数でできる方法を@koshian2さんに教えていただきました。
img_tensor=torch.randint(0, 4, (8,8))
torch.eye(4)[img_tensor].permute([2,0,1])
Tensorはインデックス指定もできるんですね!