はじめに
epochを最初100回に設定して、学習をスタートしたときに、lossの状態から10回ぐらいで中身をのぞいてみて、また再開したいときなど、学習途中にストップさせる「ボタン」があると便利かもしれません。
以前に書いた
Deconvolution 2Dの学習経過をjupyter上で描画(Bokeh)
にスタートとストップ「ボタン」を追加してみた動画↓になります。
学習途中でストップさせる「ボタン」を作る(Jupyter) https://t.co/ArppKagoLi
— さまこば (@samacoba) 2018年8月2日
startボタンで学習をスタートさせて、途中10epochくらいで、stopボタンで停止させています。
途中ストップした状態で、Deconvolution2DのウェイトWをprintして確認してから再スタートさせています。
説明
ソースは
https://github.com/samacoba/Mytest/blob/master/No_03_ipy_stop_learning.ipynb
においてあります。
#トレーニング
def training():
global epoch
while epoch < 50:
epoch = epoch + 1
#1層のDeconvolutionを通してロスを計算しアップデート
model.cleargrads()
img_y = model(img_p)
loss = F.mean_squared_error(img_y, img_core)
loss.backward()
optimizer.update()
#画像・ロスデータをセット
rend1.data_source.data['image'] = [img_p[0][0]]
rend2.data_source.data['image'] = [img_y.data[0][0]]
plt1.title.text='epoch = '+str(epoch)
plt2.title.text='loss = '+str(loss.data)
push_notebook(handle = handle)#表示をアップデート
time.sleep(0.5)
if(stop_flag == True):#ストップフラグがTrueでトレーニング停止
break
トレーニング関数の中にて、stop_flagをepoch毎にチェックして、stop_flagがTrueになるとループを抜けるようにしてあります。
#ウィジェット
import ipywidgets as widgets
from IPython.display import display
start_button = widgets.Button(description="Start")#スタートボタンウィジェット
stop_button = widgets.Button(description="Stop")#ストップボタンウィジェット
#スタートボタンをクリック時の実行関数を定義
def on_start_button_clicked(b):
global stop_flag
if stop_flag == True:#2回押し防止用
stop_flag = False
#別スレッドでトレーニングを実行する必要がある
import threading
thread = threading.Thread(target = training)
thread.start()#トレーニングスタート
#ストップボタンをクリック時の実行関数を定義
def on_stop_button_clicked(b):
global stop_flag
stop_flag = True
#ボタンをクリック時の実行関数をウィジェットに結びつける
start_button.on_click(on_start_button_clicked)
stop_button.on_click(on_stop_button_clicked)
jupyter widgetsを使ってスタートとストップボタンを作っています。
ストップボタンを押すとstop_flagがTrueに変わるので、実行途中のepochが終わるとループから抜け出します。
training関数は別スレッドで行わないとうまくいきませんでした。
ちなみにボタンを作らなくてもtrainingを別スレッドにしておけば、直接セルでstop_flag = Trueを実行にて停止できます。