torch.nn.functional.one_hot
基本的な使い方
torch.nn.functional.one_hot
を用いることでラベルから1-hotベクトルを作成することができます。torch.nn.functional.one_hot
はたとえば下記のように用いることができます。
import torch
import torch.nn.functional as F
x1 = torch.arange(0, 6)
x2 = torch.arange(0, 6) % 3
x2_one-hot = F.one_hot(x2)
print(x1)
print(x2)
print(x2_one-hot)
・実行結果
tensor([0, 1, 2, 3, 4, 5])
tensor([0, 1, 2, 0, 1, 2])
tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
num_classes
DeepLearningのDataLoader
を用いてバッチを生成する場合など、バッチ単位でone_hot
の挙動が変わる場合があります。このような際に用いると良いのがnum_classes
です。
import torch
import torch.nn.functional as F
print(F.one_hot(torch.arange(0, 6) % 3, num_classes=5))
print(F.one_hot(torch.arange(0, 6)))
print(F.one_hot(torch.arange(0, 6), num_classes=10))
・実行結果
tensor([[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0]])
tensor([[1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 1]])
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]])
上記の実行例はどれも「one_hot
関数に与える第一引数の最大値+1」以上の値をnum_classes
に与えていることに注意しておくと良いです。たとえば下記のようなコードはエラーになります。
print(F.one_hot(torch.arange(1, 6), num_classes=5))
print(F.one_hot(torch.arange(1, 6), num_classes=5))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Class values must be smaller than num_classes.
要素の型
one_hot
関数の引数に与えるテンソルのそれぞれの要素はfloat
型ではなくint
型である必要があります。たとえば下記のように要素がfloat
型であるテンソルを引数に与えるとエラーが出力されます。
import torch
import torch.nn.functional as F
x = torch.arange(0., 6.)
print(x.dtype)
print(F.one_hot(x))
・実行結果
torch.float32
...
print(F.one_hot(x))
^^^^^^^^^^^^
RuntimeError: one_hot is only applicable to index tensor of type LongTensor.
上記のようなエラーが出る場合は下記のようにlong()
メソッドを用いることでone_hot
関数を実行できるようになります。
import torch
import torch.nn.functional as F
x = torch.arange(0., 6.)
print(x.dtype)
x = x.long()
print(x.dtype)
print(F.one_hot(x))
print(F.one_hot(x).dtype)
・実行結果
torch.float32
torch.int64
tensor([[1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 1]])
torch.int64