Create Pytorch DataLoader from numpy.array

Posted at 2019-01-12

I stumbled a bit when building a PyTorch DataLoader from a scikit-learn dataset (an ndarray).

Here is a note for future reference.

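# Create the data
import numpy as np
import torch
# fetch_mldata was removed in scikit-learn 0.22; fetch_openml is its replacement
from sklearn.datasets import fetch_openml

# as_frame=False returns NumPy arrays instead of a pandas DataFrame
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X = mnist.data.astype(np.float32)  # shape (70000, 784)
y = mnist.target.astype(np.int64)  # shape (70000,)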


# Take the instances out of the ndarray one at a time, convert each to a torch.Tensor, and stack them
tensor_X = torch.stack([torch.from_numpy(np.array(i)) for i in X])
tensor_y = torch.stack([torch.from_numpy(np.array(i)) for i in y])

# Split into train and test sets
train_size = 60000
X_train = tensor_X[:train_size]
y_train = tensor_y[:train_size]
X_test = tensor_X[train_size:]
y_test = tensor_y[train_size:]

# Create the DataLoaders
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train_dataset)  # pass the dataset, not the loader itself

test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(test_dataset)
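
The per-row conversion above works, but it may not be necessary: `torch.from_numpy` also accepts the whole array at once. The following is a minimal sketch of that alternative, not the method used in this note. It assumes a small synthetic array with the same dtypes and column count as MNIST so that it runs without a download, and the `batch_size=64` and `shuffle=True` arguments are illustrative choices.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST (assumption: same dtypes and 784 columns, fewer rows)
rng = np.random.default_rng(0)
X = rng.random((1000, 784), dtype=np.float32)
y = rng.integers(0, 10, size=1000, dtype=np.int64)

# Convert each array in a single call; the tensors share memory with the arrays
tensor_X = torch.from_numpy(X)
tensor_y = torch.from_numpy(y)

dataset = TensorDataset(tensor_X, tensor_y)
# batch_size and shuffle are illustrative; the default is batch_size=1 with no shuffling
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Pull one mini-batch to check the shapes
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([64, 784]) torch.Size([64])
```

Because `torch.from_numpy` shares memory with the source array, modifying `X` afterwards also changes `tensor_X`; copy the array first if that is not wanted.
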
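X = mnist.data.astype(np.float32)
y = mnist.target.astype(np.int64)

The dtypes are cast in these two lines of the data-creation step to match what PyTorch expects: float32 for the inputs, which is the default dtype of model parameters, and int64 for the class labels, which losses such as torch.nn.CrossEntropyLoss require for class-index targets.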
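
As a rough illustration of why the cast matters, the sketch below feeds a tiny hand-made batch through a linear layer and a cross-entropy loss. The `nn.Linear(784, 10)` model and the four-row arrays are hypothetical and only serve to show the dtypes involved; they are not part of the original note.

```python
import numpy as np
import torch
import torch.nn as nn

# NumPy creates float64 arrays by default, and from_numpy preserves the dtype
X = np.zeros((4, 784))
y = np.array([0, 1, 2, 3], dtype=np.int32)
print(torch.from_numpy(X).dtype)  # torch.float64

# Hypothetical model: its parameters are float32 by default
model = nn.Linear(784, 10)

# Cast on the NumPy side so the tensors match what the model and loss expect
inputs = torch.from_numpy(X.astype(np.float32))
targets = torch.from_numpy(y.astype(np.int64))

loss = nn.CrossEntropyLoss()(model(inputs), targets)
print(inputs.dtype, targets.dtype, loss.dtype)  # torch.float32 torch.int64 torch.float32
```

Without the casts, passing the float64 inputs to the float32 model would typically fail with a dtype mismatch error, so converting before building the TensorDataset keeps the training loop free of per-batch casts.
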
