概要
この記事ではTensorFlowを用いた、モデルの出力計算の方法を中心として、model.fitに関して書きます。数行で命令が完了するので楽ですが、何をどこでやってるのかわからなかったので、ライブラリを書き換えるときはすごく大変でした。その時に学んだこと(主に出力の計算方法、printデバッグの方法)を書いていきます。
よくわかってないことも多数あるので気を付けてください。
また、私の環境はPython3.9.13、TensorFlow2.10.0です。それ以外の環境は全然知らないので注意してください。また、参考にした文献があるわけでもないので、この記事を参考にするときは必ず自分でも確認してください。
出力の計算方法について
model.fitとは
モデルの学習を進めるときに用いるもので、例えば以下のように使います。
model.fit(trial_iterator(X_train, Y_train,
batch_size=bs_EEGSym, shuffle=True,
augmentation=augmentation),
steps_per_epoch=X_train.shape[0] / bs_EEGSym,
epochs=500, validation_data=(X_validate, Y_validate),
callbacks=[early_stopping])
本来ライブラリを使わなければ、各層のパラメータの管理や、出力の計算、損失関数の計算などをすべてやる必要があり、とても大変です。しかし、これを使うことによってそれらを一切考えず、楽に処理をすることができます。
モデルのトレーニングはどこで始まる?
model.fitによって、training.pyというファイルの関数fitが呼び出されます。そしてfunctional.pyの1555行目には以下のプログラムがあり、ここからモデルのトレーニングが始まります。
for step in data_handler.steps():
with tf.profiler.experimental.Trace(
"train",
epoch_num=epoch,
step_num=step,
batch_size=batch_size,
_r=1,
):
callbacks.on_train_batch_begin(step)
tmp_logs = self.train_function(iterator)
1ステップごとにself.train_functionが呼び出されて学習が進められます。モデルを学習するとき、以下のような動く矢印みたいなのが実行画面に現れると思います。
Epoch 103/500
123/123 [==============================] - 7s 58ms/step - loss: 0.5143 - accuracy: 0.7515 - val_loss: 0.3744 - val_accuracy: 0.8000
Epoch 104/500
79/123 [==================>...........] - ETA: 2s - loss: 0.5355 - accuracy: 0.7421
それの左にある数字が1増えるたびにtrain_functionは呼び出されます。(呼び出されるたび1増えるという方が正しいかも)
iteratorにはトレーニングデータやラベルなどが入っていて、それらをself.train_functionで使って学習を進めます。
train_functionとは?
その名の通りトレーニングを実行するための関数で、make_train_functionという関数で生成されます。基本的には、関数train_stepが呼び出される(多分)ことになると思います。その一部を抜粋すると以下のようになっています。
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# Run forward pass.
with tf.GradientTape() as tape:
y_pred = self(x, training=True)
loss = self.compute_loss(x, y, y_pred, sample_weight)
トレーニングデータxに対して予測値y_predを求め、それらに対して損失を計算するものとなっています。このプログラムの4行目でいよいよ入力に対するモデルの出力が計算されます。ここでのselfはクラスFunctionalであり、functional.pyのcall関数が呼び出されます。
出力の計算はfunctional.pyを中心に行われる
functional.pyのcall関数は**_run_internal_graph**という関数を呼び出すためのものです。出力を計算するための主要なプログラムは以下の部分であり、
for depth in depth_keys:
nodes = nodes_by_depth[depth]
for node in nodes:
if node.is_input:
continue # Input tensors already exist.
if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
continue # Node is not computable, try skipping.
args, kwargs = node.map_arguments(tensor_dict)
outputs = node.layer(*args, **kwargs)
# Update tensor_dict.
for x_id, y in zip(
node.flat_output_ids, tf.nest.flatten(outputs)
):
tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
以上のプログラムのnodesには自分で定義したモデルの各層(例えば畳み込み層とか、平均プーリング層とか)が格納されています。これを入力層から順に取り出して、計算を進めていくものとなっています。
11行目では、*argsに保存された前の層の出力などをもとに、nodeが持つ層の出力を計算します。
各層の計算
各層の出力を求めるために、一度functional.pyと同ディレクトリ内base_layer.pyの__call__という関数が呼び出されます。続いて、その関数内で各層に対応するファイルのcall関数が呼び出され、そこで具体的な計算が行われます。
トレーニングに限らず、テストなどで用いるmodel.predictとかで出力を計算するときも、Functional.pyを中心に同じ処理が行われます。
ライブラリのデバッグについて
ライブラリの中身のデバッグをするときに私はすごく苦労しました。なぜならどこのファイルのどの関数が呼び出されてるのか全く分からなかったからです。ミスってただけかもしれないけど、ブレークポイントとかもよく使えなかったし。例えばtrain_stepで示した、以下のようなプログラムでは特に困りました。
y_pred = self(x, training=True)
急にselfとか言われても意味わかんないですよね。そこで使えるのがtf.printです。
tf.print("\n",self)
y_pred = self(x, training=True)
以上のようにtf.printを追加すると、
<keras.engine.functional.Functional object at 0x000002346D63C7C0>
1/18 [>.............................] - ETA: 40s - loss: 1.4693 - accuracy: 0.5625
<keras.engine.functional.Functional object at 0x000002346D63C7C0>
2/18 [==>...........................] - ETA: 10s - loss: 1.4093 - accuracy: 0.5000
<keras.engine.functional.Functional object at 0x000002346D63C7C0>
3/18 [===>..........................] - ETA: 9s - loss: 1.2918 - accuracy: 0.5104
このようにselfについてちょっとだけ詳しく表示してくれました。このおかげで、関数が何を指しているのかかなり理解しやすくなりました。
ただのprintではだめなの?って思う方へ向けて、先ほどのtf.printを普通のprintに変えたときの実行例を示します。
<keras.engine.functional.Functional object at 0x0000027862F2DDF0>
18/18 [==============================] - 16s 774ms/step - loss: 1.4685 - accuracy: 0.4850 - val_loss: 0.7234 - val_accuracy: 0.4600
Epoch 2/500
5/18 [=======>......................] - ETA: 9s - loss: 1.3248 - accuracy: 0.5625
selfの内容を一回だけしか表示してくれなくなってしまいました。別にいいじゃん。と思うかもしれませんが、前述したとおり、関数train_functionは1ステップごとに呼び出され、それが呼び出す関数train_stepも同じく1ステップごとに呼び出されないとおかしいです。以上のような実行結果になってしまうと、プログラムの理解に大きく支障をきたします。(私はここでかなり詰まりました)
また、トレーニングデータxを通常のprintで表示しようとすると、
Tensor("IteratorGetNext:0", shape=(None, None, None, None), dtype=float32)
このように中身を見せてもらえなくなってしまいます。(どっかで理由見かけた気がするんですけど忘れちゃいました)しかし、tf.printを使えば中身までしっかり見せてくれます。
以上のことから、私はtensorflowライブラリをprint文デバッグするときはtf.printを使うことをお勧めします。
今述べたことは確かに大事ですが、結局のところ気合が一番重要だったと思います。
最後に
今回は自分が卒業研究をやるにあたって、ライブラリを改造するために少しだけ理解したかもしれないことを書きました。卒論に乗せるような内容じゃないし、どうしようかと思ってたのでちょうどよかったです。
研究室ではpytorchが主流だし、ライブラリに興味ない人はこんな記事読まないだろうし、ライブラリに興味ある人にとっては常識だったかもしれないですが、自分が卒業研究をやる中でもかなり上位に食い込むレベルで頑張った事だと思うので書かせていただきました。ここまで読んでくださってありがとうございました。