PyTorchでTensorDatasetを作成するときのメモ。
numpyからTensorDataset
1. インポート
import numpy as np
import torch.utils.data
2. 入力データと教師データを準備
例として入力データと教師データをランダムに生成。
# 入力データ
x = np.random.randn(100, 5)
# 教師データ
t = np.random.randint(2, size=10)
3. torch.tensor型に変換
x = torch.tensor(x, dtype=torch.float32)
t = torch.tensor(t, dtype=torch.int64)
4. torch.utils.data.TensorDatasetを生成
dataset = torch.utils.data.TensorDataset(x, t)
確認
# 1つ目のデータを確認
print('入力データ:', dataset[0][0])
print('教師データ:', dataset[0][1])