はじめに
初投稿です.
PyTorch で自作の DataLoader を使った画像の読み込みがやけに不安定だったり,最悪の場合フリーズしてしまうということが度々起こったので,その原因と解決策を残しておきます.
原因
自作の Dataloader の中で OpenCV を使って画像読み込みをしていたのが原因でした.
どうやら OpenCV の並列処理は PyTorch の並列処理と相性が悪く,不具合を起こすことがあるみたいです.
解決方法
実際に効果があった方法を 3 つ紹介します.
1: cv2.imread() の代わりに,PILでの読み込みを使う
デフォルトのDataLoaderでも使われているのでこの方法が間違いないです.
torchvision の transform を使ったデータ拡張を行う際も PIL 形式を前提にしているのでこちらが推奨されると思います.
2: DataLoader の num_workers を0に設定する
そもそもDataLoaderでの読み込みを並列処理を無効にすれば解決します.
ただし,速度は落ちます.
3: OpenCV の並列処理を無効にする
import cv2 の後に以下の2行を追加します.
( ↑ これだけだと再びdead lockしてしまいました.)
import cv2 の前後に以下の行を追加します.
実装の都合上 ( & PILに書き換えるのが面倒だったため),私はこれを採用し,フリーズは起きなくなりました.
import os # ADD
os.environ["OMP_NUM_THREADS"] = "1" # ADD
os.environ["MKL_NUM_THREADS"] = "1" # ADD
import cv2
cv2.setNumThreads(0) # ADD
cv2.ocl.setUseOpenCL(False) # ADD
...
class MyDatasets(torch.utils.data.Dataset):
...
def __getitem__(self, idx, ...):
...
img = cv2.imread(paths[idx],-1)
...
おわりに
補足箇所がありましたら,ご教授いただけると幸いです.
参考