LoginSignup
18
18

More than 3 years have passed since last update.

PytorchのTensorをonehotにする

Last updated at Posted at 2019-09-18

PytrochでVOCのマスク画像をonehotにしなければならなかったのですが、少し調べるのに時間がかかったのでここに記します。

for文を回せばできることはできますが、
tensorのままスマートにやりたかったので探しました。

onehot化する

def onehot(image_tensor, n_clsses):
    h, w = image_tensor.size()
    onehot = torch.LongTensor(n_classes, h, w).zero_()
    image_tensor = image_tensor.unsqueeze_(0)
    onehot = onehot.scatter_(0, image_tensor, 1)
    return onehot

簡単にコードの解説をします。

最初にtorch.LongTensor(n_classes, h, w).zero_()
縦横のサイズはおなじで奥行きがクラス数の、0で埋められたTensorを作ります。

次にimage_tensor = image_tensor.unsqueeze_(0)で先程作ったゼロTensorにサイズをあわせます。
ゼロTensorのサイズは(n_classes, h, w)なので、
入力する画像をサイズ(h, w)の0からクラス数までの整数をもったTensorとすると、
unsqueeze_(0)(1, h, w)としています。

最後にonehot.scatter_(0, image_tensor, 1)でondhotにすることができます。
引数はscatter_(dim, index, src)となっており、
image_tensorをインデックスとして0を1に変換するということになります。

使ってみる

実際に使ってみましょう

import torch
import torch.nn as nn
import torch.nn.functional as F

n_classes = 4
w, h = 8, 8
image_tensor = torch.randint(0, n_classes, (w, h))
print(image_tensor)
# tensor([[0, 1, 3, 3, 3, 0, 3, 0],
#         [1, 3, 2, 1, 2, 3, 1, 3],
#         [1, 0, 2, 0, 1, 3, 3, 1],
#         [1, 0, 0, 1, 3, 1, 2, 0],
#         [0, 0, 0, 3, 3, 3, 1, 3],
#         [2, 3, 2, 3, 0, 1, 2, 0],
#         [1, 1, 2, 0, 1, 3, 0, 1],
#         [2, 0, 3, 0, 1, 1, 1, 0]])


def onehot(image_tensor, n_clsses):
    h, w = image_tensor.size()
    onehot = torch.LongTensor(n_classes, h, w).zero_()
    image_tensor = image_tensor.unsqueeze_(0)
    onehot = onehot.scatter_(0, image_tensor, 1)
    return onehot

onehot_tensor = onehot(image_tensor, n_classes)
print(onehot_tensor)
# tensor([[[1, 0, 0, 0, 0, 1, 0, 1],
#          [0, 0, 0, 0, 0, 0, 0, 0],
#          [0, 1, 0, 1, 0, 0, 0, 0],
#          [0, 1, 1, 0, 0, 0, 0, 1],
#          [1, 1, 1, 0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 1, 0, 0, 1],
#          [0, 0, 0, 1, 0, 0, 1, 0],
#          [0, 1, 0, 1, 0, 0, 0, 1]],

#         [[0, 1, 0, 0, 0, 0, 0, 0],
#          [1, 0, 0, 1, 0, 0, 1, 0],
#          [1, 0, 0, 0, 1, 0, 0, 1],
#          [1, 0, 0, 1, 0, 1, 0, 0],
#          [0, 0, 0, 0, 0, 0, 1, 0],
#          [0, 0, 0, 0, 0, 1, 0, 0],
#          [1, 1, 0, 0, 1, 0, 0, 1],
#          [0, 0, 0, 0, 1, 1, 1, 0]],

#         [[0, 0, 0, 0, 0, 0, 0, 0],
#          [0, 0, 1, 0, 1, 0, 0, 0],
#          [0, 0, 1, 0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0, 0, 1, 0],
#          [0, 0, 0, 0, 0, 0, 0, 0],
#          [1, 0, 1, 0, 0, 0, 1, 0],
#          [0, 0, 1, 0, 0, 0, 0, 0],
#          [1, 0, 0, 0, 0, 0, 0, 0]],

#         [[0, 0, 1, 1, 1, 0, 1, 0],
#          [0, 1, 0, 0, 0, 1, 0, 1],
#          [0, 0, 0, 0, 0, 1, 1, 0],
#          [0, 0, 0, 0, 1, 0, 0, 0],
#          [0, 0, 0, 1, 1, 1, 0, 1],
#          [0, 1, 0, 1, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0, 1, 0, 0],
#          [0, 0, 1, 0, 0, 0, 0, 0]]])

ちゃんとサイズを保ったままonehotになっているのがわかると思います。

追記

より少ない行数でできる方法を@koshian2さんに教えていただきました。

img_tensor=torch.randint(0, 4, (8,8))
torch.eye(4)[img_tensor].permute([2,0,1])

Tensorはインデックス指定もできるんですね!

18
18
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
18