0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

TensorFlowの学習(tf.keras.Model.fit)を任意のタイミングで抜ける方法

Posted at

#1.この記事の内容

tf.keras.Model.fitで学習中のプロセスに対して,別プロセスから制御コマンドを送信することによって,任意のタイミングで関数から抜ける(学習を中断する)方法を記載します.

サンプルプログラムは筆者のGitHubに公開しています.

##1-1.実装方法の要約

  • 学習中のプロセスへの制御コマンド送信はFIFOを使用する
  • 制御コマンドは任意に定義する
    例) FIFOに文字列stopが書き込まれると学習を中断する
  • 学習中のプロセスではコールバックを用いて定期的(バッチ単位,EPOCH単位など)にFIFOから制御コマンドを読み込む
  • 学習の中断には,コールバック内でself.model属性を使用し,self.model.stop_training = Trueを設定する
  • FIFOから制御コマンドを読み込む際はos.openos.O_NONBLOCKフラグを指定してノンブロッキングで読み込む

##1-2.動作確認環境

  • Windows 11
  • WSL2 (Ubuntu 20.04)
  • tensorflow 2.4.0+nv
  • GPU: NVIDIA GeForce RTX 2070 SUPER

#2.背景

tf.keras.Model.fitで学習を開始すると,EarlyStoppingや指定EPOCH数,あるいはコールバックを用いてlossの収束状態を監視するなど,学習状態をトリガに学習を終了・中断する実装が多いです.

外部プロセスからfit()関数を抜ける形で学習を中断する方法が見つからず,学習中に外部から制御を受け取る手段を模索した末,FIFOを介したプロセス間通信を採用することとしました.

#3.実装方法解説

FIFOを介して学習を中断する方法の概略図(実装例)は下記の通りで,TensorFlowの学習はFIFO生成後に実行します.

fit_stop.png

###①FIFOを生成する

FIFO生成にはmkfifoコマンドを使用します.

$ mkfifo <fifo name>

WSLを使用する場合は,FIFOの生成先を/tmp等のLinux上に設定する必要があります.
WindowsではOSの機能としてFIFOがない為,Windows上に作成しようとするとOperation not supportedエラーが発生します.

$ cd /mnt/c/Users/<user name>
$ mkfifo ./test_fifo
mkfifo: cannot create fifo './test_fifo': Operation not supported

###②echoコマンド等でFIFOへコマンドを書き込む

学習を実行するプロセスとは異なるプロセスからFIFOへ学習中断コマンド(文字列)を書き込みます.

####FIFOへの制御コマンド書き込み例

文字列stopで学習を停止するものとしてFIFOへの書き込み例を紹介します.

#####bash

$ echo test > <fifo name>

#####Python

with open(<fifo name>, 'w') as f:
    f.write('stop\n')

#####C言語

f = fopen(<fifo name>, 'w');
fprintf(f, "stop\n");
fclose(f);

###③FIFOからノンブロッキングでコマンドを取得

TensorFlowのコールバックは下記のように記述します.
下記の例ではコンストラクタ__init__の引数でFIFOファイル名を渡して,バッチ単位でFIFOの状態を確認するものとしています.
FIFOはos.openos.O_NONBLOCKを指定することでノンブロッキングで開くことができます.

FIFOからデータを読み出す際のos.readはハンドラをブロックする為,os.O_NONBLOCKをセットしたままでは例外BlockingIOErrorが発生します.
os.open でハンドラを取得した後はノンブロッキングにしておく必要はない為,os.O_NONBLOCKをクリアします.

TensorFlowの学習中断はself.model.stop_training = Trueとすることで実現できます.

class CustomCallback(keras.callbacks.Callback):
    def __init__(self, fifo):
        super().__init__()
        self.fifo = fifo
        
    def on_train_batch_end(self, batch, logs=None):
        fd = os.open(self.fifo, os.O_RDONLY | os.O_NONBLOCK)
        flags = fcntl.fcntl(fd, fcntl.F_GETFL)
        flags &= ~os.O_NONBLOCK
        fcntl.fcntl(fd, fcntl.F_SETFL, flags)
        
        try:
            command = os.read(fd, 128)
            command = command.decode()[:-1]
            while (True):
                buf = os.read(fd, 65536)
                if not buf:
                    break
        finally:
            os.close(fd)
        
        if (command):
            if (command == 'stop'):
                print("End batch: recv command={}".format(command))
                self.model.stop_training = True
            else:
                print("End batch: recv unknown command={}".format(command))

#4.さいごに

Ctrl+C やプロセスkillによる強制中断以外の方法で,任意のタイミングで学習を終了させる方法が見つからず,(他にスマートな実現方法はありそうですが)本記事のように実装することとしました.
任意のタイミングで学習を中断するお作法があれば,コメントなどでご教示いただけますと幸いです.

#5.関連リンク

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?