1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

言語処理100本ノック(2020)-76: チェックポイント(TensorFlow)

Last updated at Posted at 2021-10-17

言語処理100本ノック 2020 (Rev2)「第8章: ニューラルネット」76本目「チェックポイント」記録です。kearaコールバック関数を使って実装しています。あまり公式ガイドにこの方法書いておらず、本当にこれでいいのか少し不安。2年くらい前はこれでよかったのだけど、数年後にこのやり方変わっていそう。
記事「まとめ: 言語処理100本ノックで学べることと成果」言語処理100本ノック 2015についてはまとめていますが、追加で差分の言語処理100本ノック 2020 (Rev2)についても更新します。

参考リンク

リンク 備考
76_チェックポイント.ipynb 回答プログラムのGitHubリンク
言語処理100本ノック 2020 第8章: ニューラルネット (PyTorchだけど)解き方の参考
【言語処理100本ノック 2020】第8章: ニューラルネット (PyTorchだけど)解き方の参考
まとめ: 言語処理100本ノックで学べることと成果 言語処理100本ノックまとめ記事
公式チュートリアル: チェックポイントの保存を移行する 移行ではないが公式なので記載
公式ガイド: Kerasモデルの保存と読み込み TF Checkpoint方式。この方式使ってない。

環境

後々GPUを使わないと厳しいので、Goolge Colaboratory使いました。Pythonやそのパッケージでより新しいバージョンありますが、新機能使っていないので、プリインストールされているものをそのまま使っています。

種類 バージョン 内容
Python 3.7.12 Google Colaboratoryのバージョン
google 2.0.3 Google Driveのマウントに使用
tensorflow 2.6.0 ディープラーニングの主要処理

第8章: ニューラルネット

学習内容

深層学習フレームワークの使い方を学び,ニューラルネットワークに基づくカテゴリ分類を実装します.

ノック内容

第6章で取り組んだニュース記事のカテゴリ分類を題材として,ニューラルネットワークでカテゴリ分類モデルを実装する.なお,この章ではPyTorch, TensorFlow, Chainerなどの機械学習プラットフォームを活用せよ.

76. チェックポイント

問題75のコードを改変し,各エポックのパラメータ更新が完了するたびに,チェックポイント(学習途中のパラメータ(重み行列など)の値や最適化アルゴリズムの内部状態)をファイルに書き出せ.

回答

回答結果

学習時にファイル書き込みしたことを知らせてくれています。今回は面倒なので3エポックで学習しています。

結果(学習時)
Epoch 1/3
334/334 [==============================] - 3s 6ms/step - loss: 1.2526 - acc: 0.6938 - val_loss: 1.1426 - val_acc: 0.7844
INFO:tensorflow:Assets written to: ./tmp/01-1.14/assets
Epoch 2/3
334/334 [==============================] - 2s 5ms/step - loss: 1.0725 - acc: 0.7785 - val_loss: 1.0120 - val_acc: 0.7867
INFO:tensorflow:Assets written to: ./tmp/02-1.01/assets
Epoch 3/3
334/334 [==============================] - 2s 5ms/step - loss: 0.9695 - acc: 0.7796 - val_loss: 0.9305 - val_acc: 0.7859
INFO:tensorflow:Assets written to: ./tmp/03-0.93/assets
CPU times: user 7.47 s, sys: 1.56 s, total: 9.03 s
Wall time: 7.68 s
<keras.callbacks.History at 0x7f8c8044c8d0>

学習語にlsコマンドでチェックポイントファイルを見ています。フォルダ名はエポック名とValidation Lossの値で作っています。

結果(学習後ファイル確認)
tmp:
01-1.14/  02-1.01/  03-0.93/

tmp/01-1.14:
assets/  keras_metadata.pb  saved_model.pb  variables/

tmp/01-1.14/assets:

tmp/01-1.14/variables:
variables.data-00000-of-00001  variables.index

tmp/02-1.01:
assets/  keras_metadata.pb  saved_model.pb  variables/

tmp/02-1.01/assets:

tmp/02-1.01/variables:
variables.data-00000-of-00001  variables.index

tmp/03-0.93:
assets/  keras_metadata.pb  saved_model.pb  variables/

tmp/03-0.93/assets:

tmp/03-0.93/variables:
variables.data-00000-of-00001  variables.index

回答プログラム 76_チェックポイント.ipynb

GitHubには確認用コードも含めていますが、ここには必要なものだけ載せています。

import tensorflow as tf
from google.colab import drive

drive.mount('/content/drive')
LOG_DIR = './logs'

%load_ext tensorboard

def _parse_function(example_proto):
    # 特徴の記述
    feature_description = {
        'title': tf.io.FixedLenFeature([], tf.string),
        'category': tf.io.FixedLenFeature([], tf.string)}
  
  # 上記の記述を使って入力の tf.Example を処理
    features = tf.io.parse_single_example(example_proto, feature_description)
    X = tf.io.decode_raw(features['title'], tf.float32)
    y = tf.io.decode_raw(features['category'], tf.int32)
    return X, y

BASE_PATH = '/content/drive/MyDrive/ColabNotebooks/ML/NLP100_2020/08.NeuralNetworks/'

def get_dataset(file_name):
    ds_raw = tf.data.TFRecordDataset(BASE_PATH+file_name+'.tfrecord')

    #shuffleはここを見て理解。データ件数取る方法がわからず、1000件に設定
    #https://qiita.com/exy81/items/d1388f6f02a11c8f1d7e
    return ds_raw.map(_parse_function).shuffle(1000).batch(32)

train_ds = get_dataset('train')
valid_ds = get_dataset('valid')

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(
        4, activation='softmax', input_dim=300, use_bias=False, kernel_initializer='random_uniform') ])
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['acc'])
model.summary()

%tensorboard --logdir $LOG_DIR

callbacks = []
callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR))
callbacks.append(tf.keras.callbacks.ModelCheckpoint('./tmp/{epoch:02d}-{val_loss:.2f}'))

model.fit(train_ds, 
          epochs=3, 
          validation_data=valid_ds, 
          callbacks=callbacks)

%ls -R tmp

回答解説

チェックポイントのコールバック

kerasのコールバック関数としてModelCheckpoint関数を使っています。引数filepathには変数を使えて、{}で囲って指定しています。細かい文法や使える変数がヘルプ文書にないので、GitHubのコードを見るしかないでしょうか。調べていません。

callbacks = []
callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR))
callbacks.append(tf.keras.callbacks.ModelCheckpoint('./tmp/{epoch:02d}-{val_loss:.2f}'))

model.fit(train_ds, 
          epochs=3, 
          validation_data=valid_ds, 
          callbacks=callbacks)
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?