0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

実務でよく使うPyTorchの機能まとめ⑤|torch.nn.functional.one_hot

Last updated at Posted at 2025-02-14

torch.nn.functional.one_hot

基本的な使い方

torch.nn.functional.one_hotを用いることでラベルから1-hotベクトルを作成することができます。torch.nn.functional.one_hotはたとえば下記のように用いることができます。

import torch
import torch.nn.functional as F

x1 = torch.arange(0, 6)
x2 = torch.arange(0, 6) % 3
x2_one-hot = F.one_hot(x2)

print(x1)
print(x2)
print(x2_one-hot)

・実行結果

tensor([0, 1, 2, 3, 4, 5])
tensor([0, 1, 2, 0, 1, 2])
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

num_classes

DeepLearningのDataLoaderを用いてバッチを生成する場合など、バッチ単位でone_hotの挙動が変わる場合があります。このような際に用いると良いのがnum_classesです。

import torch
import torch.nn.functional as F

print(F.one_hot(torch.arange(0, 6) % 3, num_classes=5))
print(F.one_hot(torch.arange(0, 6)))
print(F.one_hot(torch.arange(0, 6), num_classes=10))

・実行結果

tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0]])
tensor([[1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1]])
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]])

上記の実行例はどれも「one_hot関数に与える第一引数の最大値+1」以上の値をnum_classesに与えていることに注意しておくと良いです。たとえば下記のようなコードはエラーになります。

print(F.one_hot(torch.arange(1, 6), num_classes=5))
    print(F.one_hot(torch.arange(1, 6), num_classes=5))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Class values must be smaller than num_classes.

要素の型

one_hot関数の引数に与えるテンソルのそれぞれの要素はfloat型ではなくint型である必要があります。たとえば下記のように要素がfloat型であるテンソルを引数に与えるとエラーが出力されます。

import torch
import torch.nn.functional as F

x = torch.arange(0., 6.)

print(x.dtype)
print(F.one_hot(x))

・実行結果

torch.float32

...

    print(F.one_hot(x))
          ^^^^^^^^^^^^
RuntimeError: one_hot is only applicable to index tensor of type LongTensor.

上記のようなエラーが出る場合は下記のようにlong()メソッドを用いることでone_hot関数を実行できるようになります。

import torch
import torch.nn.functional as F

x = torch.arange(0., 6.)

print(x.dtype)
x = x.long()
print(x.dtype)
print(F.one_hot(x))
print(F.one_hot(x).dtype)

・実行結果

torch.float32
torch.int64
tensor([[1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1]])
torch.int64
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?