#1.この記事の内容
tf.keras.Model.fitで学習中のプロセスに対して,別プロセスから制御コマンドを送信することによって,任意のタイミングで関数から抜ける(学習を中断する)方法を記載します.
サンプルプログラムは筆者のGitHubに公開しています.
##1-1.実装方法の要約
- 学習中のプロセスへの制御コマンド送信はFIFOを使用する
- 制御コマンドは任意に定義する
例) FIFOに文字列stop
が書き込まれると学習を中断する - 学習中のプロセスではコールバックを用いて定期的(バッチ単位,EPOCH単位など)にFIFOから制御コマンドを読み込む
- 学習の中断には,コールバック内でself.model属性を使用し,
self.model.stop_training = True
を設定する - FIFOから制御コマンドを読み込む際はos.openに
os.O_NONBLOCK
フラグを指定してノンブロッキングで読み込む
##1-2.動作確認環境
- Windows 11
- WSL2 (Ubuntu 20.04)
- tensorflow 2.4.0+nv
- GPU: NVIDIA GeForce RTX 2070 SUPER
#2.背景
tf.keras.Model.fitで学習を開始すると,EarlyStoppingや指定EPOCH数,あるいはコールバックを用いてlossの収束状態を監視するなど,学習状態をトリガに学習を終了・中断する実装が多いです.
外部プロセスからfit()
関数を抜ける形で学習を中断する方法が見つからず,学習中に外部から制御を受け取る手段を模索した末,FIFOを介したプロセス間通信を採用することとしました.
#3.実装方法解説
FIFOを介して学習を中断する方法の概略図(実装例)は下記の通りで,TensorFlowの学習はFIFO生成後に実行します.
###①FIFOを生成する
FIFO生成にはmkfifo
コマンドを使用します.
$ mkfifo <fifo name>
WSLを使用する場合は,FIFOの生成先を/tmp
等のLinux上に設定する必要があります.
WindowsではOSの機能としてFIFOがない為,Windows上に作成しようとするとOperation not supported
エラーが発生します.
$ cd /mnt/c/Users/<user name>
$ mkfifo ./test_fifo
mkfifo: cannot create fifo './test_fifo': Operation not supported
###②echoコマンド等でFIFOへコマンドを書き込む
学習を実行するプロセスとは異なるプロセスからFIFOへ学習中断コマンド(文字列)を書き込みます.
####FIFOへの制御コマンド書き込み例
文字列stop
で学習を停止するものとしてFIFOへの書き込み例を紹介します.
#####bash
$ echo test > <fifo name>
#####Python
with open(<fifo name>, 'w') as f:
f.write('stop\n')
#####C言語
f = fopen(<fifo name>, 'w');
fprintf(f, "stop\n");
fclose(f);
###③FIFOからノンブロッキングでコマンドを取得
TensorFlowのコールバックは下記のように記述します.
下記の例ではコンストラクタ__init__
の引数でFIFOファイル名を渡して,バッチ単位でFIFOの状態を確認するものとしています.
FIFOはos.open
にos.O_NONBLOCK
を指定することでノンブロッキングで開くことができます.
FIFOからデータを読み出す際のos.read
はハンドラをブロックする為,os.O_NONBLOCK
をセットしたままでは例外BlockingIOError
が発生します.
os.open
でハンドラを取得した後はノンブロッキングにしておく必要はない為,os.O_NONBLOCK
をクリアします.
TensorFlowの学習中断はself.model.stop_training = True
とすることで実現できます.
class CustomCallback(keras.callbacks.Callback):
def __init__(self, fifo):
super().__init__()
self.fifo = fifo
def on_train_batch_end(self, batch, logs=None):
fd = os.open(self.fifo, os.O_RDONLY | os.O_NONBLOCK)
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
flags &= ~os.O_NONBLOCK
fcntl.fcntl(fd, fcntl.F_SETFL, flags)
try:
command = os.read(fd, 128)
command = command.decode()[:-1]
while (True):
buf = os.read(fd, 65536)
if not buf:
break
finally:
os.close(fd)
if (command):
if (command == 'stop'):
print("End batch: recv command={}".format(command))
self.model.stop_training = True
else:
print("End batch: recv unknown command={}".format(command))
#4.さいごに
Ctrl+C
やプロセスkillによる強制中断以外の方法で,任意のタイミングで学習を終了させる方法が見つからず,(他にスマートな実現方法はありそうですが)本記事のように実装することとしました.
任意のタイミングで学習を中断するお作法があれば,コメントなどでご教示いただけますと幸いです.
#5.関連リンク