はじめに
torch.nn.functional.one_hot
で RuntimeError: one_hot is only applicable to index tensor
したときの対処法です.
原因がわからなすぎて頭おかしくなりそうでした...
環境
- Python 3.11
- torch==2.1.2
具体例
import numpy as np
import torch
import torch.nn.functional as F
arr = np.array([0,1,2,3,0,1,2,3,0,1,2,3])
arr_T = torch.tensor(arr, dtype=torch.int32)
T = F.one_hot(arr_T)
# RuntimeError: one_hot is only applicable to index tensor.
一見問題なさそうですが,エラーを吐きます.
原因と対処法
torch.nn.functional.one_hot
が対応している型が torch.int64
型のみなのが原因.したがって,以下のようにコードを書き換えればよい.
(そういう大事なことはドキュメントに書いてくれ...)
import numpy as np
import torch
import torch.nn.functional as F
arr = np.array([0,1,2,3,0,1,2,3,0,1,2,3])
arr_T = torch.tensor(arr, dtype=torch.int64)
T = F.one_hot(arr_T)
追記
one_hot
が対応している型を明示してくれと言いましたが,実は下記のように明示してくれてました.
LongTensor
というのが torch.int64
とほとんど同じ意味みたいです.
実際,以下のようなプログラムで確認すると,いずれも同じ型であることが分かります.
arr1 = torch.tensor([[1,2],[3,4]])
arr2 = torch.LongTensor([[1,2],[3,4]])
print("arr1 :",arr1.dtype)
print("arr2 :",arr2.dtype)
# arr1 : torch.int64
# arr2 : torch.int64
参考文献
- 公式ドキュメント: https://pytorch.org/docs/stable/generated/torch.nn.functional.one_hot.html
- github のissue: https://github.com/pytorch/pytorch/issues/86162
-
LongTensor
はここを参照 https://pytorch.org/docs/stable/tensors.html