1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

torch.nn.functional.one_hot で "RuntimeError: one_hot is only applicable to index tensor" したときの対処法

Last updated at Posted at 2024-02-13

はじめに

torch.nn.functional.one_hotRuntimeError: one_hot is only applicable to index tensor したときの対処法です.
原因がわからなすぎて頭おかしくなりそうでした...

環境

  • Python 3.11
  • torch==2.1.2

具体例

import numpy as np
import torch
import torch.nn.functional as F

arr = np.array([0,1,2,3,0,1,2,3,0,1,2,3])
arr_T = torch.tensor(arr, dtype=torch.int32)
T = F.one_hot(arr_T)

# RuntimeError: one_hot is only applicable to index tensor.

一見問題なさそうですが,エラーを吐きます.

原因と対処法

torch.nn.functional.one_hot が対応している型が torch.int64 型のみなのが原因.したがって,以下のようにコードを書き換えればよい.
(そういう大事なことはドキュメントに書いてくれ...)

import numpy as np
import torch
import torch.nn.functional as F

arr = np.array([0,1,2,3,0,1,2,3,0,1,2,3])
arr_T = torch.tensor(arr, dtype=torch.int64)
T = F.one_hot(arr_T)

追記

one_hot が対応している型を明示してくれと言いましたが,実は下記のように明示してくれてました.
LongTensor というのが torch.int64 とほとんど同じ意味みたいです.

実際,以下のようなプログラムで確認すると,いずれも同じ型であることが分かります.

arr1 = torch.tensor([[1,2],[3,4]])
arr2 = torch.LongTensor([[1,2],[3,4]])
print("arr1 :",arr1.dtype)
print("arr2 :",arr2.dtype)

# arr1 : torch.int64
# arr2 : torch.int64

参考文献

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?