0
2

More than 3 years have passed since last update.

ketosで音響解析その2

Posted at

はじめに

前回の記事でデータベースを作成したので今回はそれを基に検出器を作成していきます。

必要なデータはここにあります。
前回作成したデータベースや今回のコードも入っています。

分類器のトレーニング

まずはチュートリアルで使用されるランダムシードを定義する必要があるようです。

import numpy as np
np.random.seed(1000)

import tensorflow as tf
tf.random.set_seed(2000)

インポートする

今回使うものです。

import ketos.data_handling.database_interface as dbi
from ketos.neural_networks.resnet import ResNetInterface
from ketos.data_handling.data_feeding import BatchGenerator

トレーニング

まずはデータベースに接続します。

db = dbi.open_file("database.h5", 'r')

次にdb内のテーブルを使えるようにしておきます。

train_data = dbi.open_table(db, "/train/data")
val_data = dbi.open_table(db, "/val/data")

学習させるためのバッチサイズなどの指定をketosに沿って設定していきます。

def transform_batch(X, Y):
    x = X.reshape(X.shape[0],X.shape[1],X.shape[2],1)
    y = tf.one_hot(Y['label'], depth=2, axis=1).numpy()
    return x, y
train_generator = BatchGenerator(batch_size=128, data_table=train_data, 
                                  output_transform_func=ResNetInterface.transform_batch,
                                  shuffle=True, refresh_on_epoch_end=True)

val_generator = BatchGenerator(batch_size=128, data_table=val_data,
                                 output_transform_func=ResNetInterface.transform_batch,
                                 shuffle=True, refresh_on_epoch_end=False)

ニューラルネットワークの作成とトレーニング

チュートリアルではResNetを使用します。

下準備

resnet = ResNetInterface.build_from_recipe_file("recipe.json")
resnet.train_generator = train_generator
resnet.val_generator = val_generator
resnet.checkpoint_dir = "checkpoints"

学習開始

resnet.train_loop(n_epochs=30, verbose=True)

学習が完了したらデータベースを閉じ、モデルを保存します。

db.close()
resnet.save_model('narw.kt',audio_repr_file='spec_config.json')

今回はこれで終了します。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2