one_hot = torch.nn.functional.one_hot(torch.tensor([2, 0, 1]), num_classes=4)
one_hot
# output:
# tensor([[0, 0, 1, 0],
# [1, 0, 0, 0],
# [0, 1, 0, 0]])
戻すには、以下。
torch.argmax(one_hot, dim=1)
# output:
# tensor([2, 0, 1])
なぜか検索に引っ掛かりにくいので・・・。