#まえがき
Pytorchの嬉しいところは,HDDからデータをメモリに並列に読み込む仕組みが標準で実装されているところだと思う(Kerasでは自前で作成する必要があってだるい).
実現のためには,torch.utils.data.Dataset
とtorch.utils.data.Dataloader
を使用する.それらを使用する際に一か所つまづいた.
実行環境は以下.
- OS: Windows10 (64bit, 1903)
- pytorch: 1.1
TL; DR
Windowsで明示的・暗黙的に並列処理するときには,並列処理する対象のクラスを別のファイルに定義すること.
つまり,DataLoader
のインスタンス化とDataset
クラスの定義を同じ.py
で行ってはいけない.
詳細に
以下のように書くと,Broken Pipe
だのなんだのよくわからんエクセプションが吐かれて泣きたくなる.
(ちなみにこのコードは切ったり貼ったり名前変えたりしたので,正確ではないです.)
main.py
if __name__ == '__main__':
... do something
import numpy as np
class MyDataset(torch.utils.data.Dataset):
def __init__(self, csv_path):
data = np.loadtxt(csv_path, delimiter="\t")
self.vector = data[:, 0:9]
self.target = data[:, 10]
def __len__(self):
return len(self.target)
def __getitem__(self, idx):
return ( self.vector[idx], self.target[idx] )
train_loader = torch.utils.data.DataLoader(
MyDataset(path),
batch_size=batch_size,
shuffle=True,
num_workers=7,
pin_memory=True
)
動作させるためには,MyDataset
の定義を別のファイルに移動させればよい.
mydataset.py
import numpy as np
class MyDataset(torch.utils.data.Dataset):
def __init__(self, csv_path):
data = np.loadtxt(csv_path, delimiter="\t")
self.vector = data[:, 0:9]
self.target = data[:, 10]
def __len__(self):
return len(self.target)
def __getitem__(self, idx):
return ( self.vector[idx], self.target[idx] )
main.py
if __name__ == '__main__':
... do something
train_loader = torch.utils.data.DataLoader(
MyDataset(path),
batch_size=batch_size,
shuffle=True,
num_workers=7,
pin_memory=True
)
あとがき
たぶんLinuxマシンでは直面しない問題なので,細部でフラストレーション感じたくなければUbuntuとか使っとくのがいいんだと思います.
僕はWindows大正義だと思っているのでもう少し苦しもうと思います.