はじめに
前回の記事でデータベースを作成したので今回はそれを基に検出器を作成していきます。
必要なデータはここにあります。
前回作成したデータベースや今回のコードも入っています。
分類器のトレーニング
まずはチュートリアルで使用されるランダムシードを定義する必要があるようです。
import numpy as np
np.random.seed(1000)
import tensorflow as tf
tf.random.set_seed(2000)
インポートする
今回使うものです。
import ketos.data_handling.database_interface as dbi
from ketos.neural_networks.resnet import ResNetInterface
from ketos.data_handling.data_feeding import BatchGenerator
トレーニング
まずはデータベースに接続します。
db = dbi.open_file("database.h5", 'r')
次にdb内のテーブルを使えるようにしておきます。
train_data = dbi.open_table(db, "/train/data")
val_data = dbi.open_table(db, "/val/data")
学習させるためのバッチサイズなどの指定をketosに沿って設定していきます。
def transform_batch(X, Y):
x = X.reshape(X.shape[0],X.shape[1],X.shape[2],1)
y = tf.one_hot(Y['label'], depth=2, axis=1).numpy()
return x, y
train_generator = BatchGenerator(batch_size=128, data_table=train_data,
output_transform_func=ResNetInterface.transform_batch,
shuffle=True, refresh_on_epoch_end=True)
val_generator = BatchGenerator(batch_size=128, data_table=val_data,
output_transform_func=ResNetInterface.transform_batch,
shuffle=True, refresh_on_epoch_end=False)
ニューラルネットワークの作成とトレーニング
チュートリアルではResNetを使用します。
下準備
resnet = ResNetInterface.build_from_recipe_file("recipe.json")
resnet.train_generator = train_generator
resnet.val_generator = val_generator
resnet.checkpoint_dir = "checkpoints"
学習開始
resnet.train_loop(n_epochs=30, verbose=True)
学習が完了したらデータベースを閉じ、モデルを保存します。
db.close()
resnet.save_model('narw.kt',audio_repr_file='spec_config.json')
今回はこれで終了します。