やりたいこと
Windows環境で、こちらのプログラムを使ってLoRAを学習しようとしたところ、pickle関連のエラーが出てしまいました。
Linuxでは動いていたので、どうもWindowsだけだと思い調べると、Windowsではマルチプロセスのモジュールとpickleのモジュールがうまく連携できないようでした。
動かしたかった学習コマンド↓
accelerate launch ^
--num_processes=1 ^
--num_machines=1 ^
--mixed_precision="fp16" ^
--dynamo_backend="no" ^
train_lora_dreambooth.py ^
--pretrained_model_name_or_path="MODEL_NAME" ^
--instance_prompt="PROMPT" ^
--instance_data_dir="./datasets/xxx" ^
--output_dir="./output" ^
--resolution=512 ^
--train_batch_size=1 ^
--learning_rate=1e-4 ^
--learning_rate_text=5e-5 ^
--lr_scheduler="constant" ^
--lr_warmup_steps=0 ^
--gradient_accumulation_steps=1 ^
--max_train_steps=3000 ^
--train_text_encoder ^
--center_crop
発生したエラー(抜粋)
File "D:\ProgramFiles\WPy64-39100\python-3.9.10.amd64\lib\multiprocessing\context.py", line 327, in _PEOFErroropen
: Ran out of input
return Popen(process_obj)
File "D:\ProgramFiles\WPy64-39100\python-3.9.10.amd64\lib\multiprocessing\popen_spawn_win32.py", line 93, in __init__
reduction.dump(process_obj, to_child)
File "D:\ProgramFiles\WPy64-39100\python-3.9.10.amd64\lib\multiprocessing\reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'main.<locals>.collate_fn'
Can't pickle local object ~~
とあるので、pickle関連のエラーと分かりました。
試行錯誤したこと
調べると、pickle
の互換のライブラリとしてdill
を使えばよいとの情報があり、pip install dill
でインストールしたあと、エラーで出ていたpython-3.9.10.amd64\lib\multiprocessing\reduction.py
のライブラリインポート部分を直接変更してみました。
# import pickle -> import dill as pickle
L15: import dill as pickle
しかし、dill
そのものもmultiprocessing.reduction
にある一部の関数を読み込んでいるらしく、循環インポートに陥ってしまい、またもエラーでした。
確かに、~~\Lib\site-packages\dill\_dill.py
の中身を見ると、
L160: from multiprocessing.reduction import _reduce_socket as reduce_socket
とあり、dill自体もmultiprocessing.reduction
を呼び出しています。
呼び出しているといっても一部の関数だけなので、この部分を別のモジュールとして分離しておけばよいという方法にたどり着きました。
最終的な対処方法
-
~~\lib\multiprocessing\reduction.py
を同じディレクトリにコピーし、reduction_original.py
とする。 -
reduction.py
の中身のみ、15行目をimport dill as pickle
に変更する。 -
~~\Lib\site-packages\dill\_dill.py
の160行目を、from multiprocessing.reduction_original import _reduce_socket as reduce_socket
に変える。
こうすることで、dillとmultiprocessing.reductionの循環インポートが解消され、無事にWindowsでLoRAを学習することができました!