LoginSignup
0
1

More than 3 years have passed since last update.

TensorFlow.jsでDeepLearning(Audio recognition using transfer learning)

Last updated at Posted at 2019-05-18

こんにちわ。Electric Blue Industries Ltd.という、ITで美を追求するファンキーでマニアックなIT企業のマッツと申します。TensorFlow.jsでDeepLearningのチュートリアル「Audio recognition using transfer learning」の詳細解説です。

これはTensorFlow.JSの公式サイトにある「TensorFlow.js — Audio recognition using transfer learning」をコードの中に記載したコメントで詳細に解説したものです。解説の利便性によりコードの部分の位置関係は変更してありますが、内容に変化はありません。実際に動作するデモはこちらで見られます。

1. コード

1.1. html

ライブラリを読み込んでDeep LearningのためのJavaScriptを実行するHTMLです。

展開してコードを見る
index.html
<html>
    <head>
        <!-- Import TensorFlow.js -->
        <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.js"></script>
        <!-- Import tfjs-vis -->
        <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js"></script>
        <!-- imports the pre-trained Speech Commands model -->
        <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands@0.3.5/dist/speech-commands.min.js"></script>
    </head>
    <body>
        <!-- 3種類の言葉を音声で入力する際に押すボタン(ラベルに呼応) -->
        <button id="left" onmousedown="collect(0)" onmouseup="collect(null)">Left</button>
        <button id="right" onmousedown="collect(1)" onmouseup="collect(null)">Right</button>
        <button id="noise" onmousedown="collect(2)" onmouseup="collect(null)">Noise</button>
        <br/><br/>
        <!-- モデルの学習を行うトリガー -->
        <button id="train" onclick="train()">Train</button>
        <br/><br/>
        <!-- 学習が済んだモデルで音声入力の該当するラベル(left/right/noise)をpredictするトリガー -->
        <button id="listen" onclick="listen()">Listen</button>
        <input type="range" id="output" min="0" max="10" step="0.1">
        <br/><br/>
        <!-- ここにpredist結果のラベル(0/1/2)が出力される -->
        <div id="console"></div>
        <script src="script.js"></script>
    </body>
</html>

注意事項:オリジナルのコードではSpeech Commandsのversion0.3.6がHTMLにインクルードされていますが、そのままだと正常に実行されません。version0.3.5をインクルードしてください。

1.2. JavaScript

上記のHTMLにインクルードされてDeep Learning処理を行うJavaScriptです。

展開してコードを見る
script.js
/******************************************************************
TensorFlow.js — Transfer learning audio recognizer
url: https://www.tensorflow.org/js/tutorials/transfer/audio_recognizer#0
filename: script.js
copyrighted to: tensorflow.org
composed by: Mats (Electric Blue Industries Ltd.)
description: 移転学習で学習済みパラメータを取り込み、ブラウザでその場でトレーニングするカスタムの音声分類器であるを構築する
important note: this set of code works with Speech Command ver. 0.3.5 but does not work with ver. 0.3.6. 
******************************************************************/

let recognizer;

async function app() {

    // speechCommandでブラウザ入力音声を高速フーリエ変換をおこなって音声をデータ化する音声認識機能を新規作成する。
    recognizer = speechCommands.create('BROWSER_FFT');

    // 上記の新規作成により下記のような特性のFFT音声認識機能が作成される(一部情報のみ)
    //
    // FFT_SIZE: 1024
    // MODEL_URL_PREFIX: "https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft"
    // SAMPLE_RATE_HZ: 44100
    // elementsPerExample: 9976
    // nonBatchInputShape: (3) [43, 232, 1]

    // 既に学習済みのモデルパラメータが読み込まれたのを確認して次のステップへ。
    await recognizer.ensureModelLoaded();

    // 描き関数を用いてモデル枠を作成する。
    buildModel();

}

//******************************************************************
// 1. buildModel: モデルを生成する
//******************************************************************

const NUM_FRAMES = 3;
// 入力テンソルの形状(これは移転学習する学習済みモデルによって決まるものなので移転元に合わせる)
const INPUT_SHAPE = [NUM_FRAMES, 232, 1];
let model;

// 上記のapp()内で用いられる、モデル作成の関数
function buildModel() {

    // シーケンシャルモデル(線形回帰モデル)の枠組みの作成
    model = tf.sequential();

    // 震度方向2次元畳み込み層を追加
    model.add(tf.layers.depthwiseConv2d({
        depthMultiplier: 8,
        kernelSize: [NUM_FRAMES, 3],
        // 活性化関数はRelu(=Rectified Linear Unit)すなわちランプ関数/正規化線形関数
        activation: 'relu',
        inputShape: INPUT_SHAPE
    }));

    // 二次元マックスプーリング層の追加
    model.add(tf.layers.maxPooling2d({
        poolSize: [1, 2],
        strides: [2, 2]
    }));

    // 出力を1次元にreshapeする層を追加
    model.add(tf.layers.flatten());

    // 上記のflat化層から入力を受けて3つの出力(Left/Right/Noise)を出力する層を追加
    model.add(tf.layers.dense({
        // ユニット(別名:ノード)は3個
        units: 3,
        // 活性化関数はソフトマックス
        activation: 'softmax'
    }));

    // オプティマイザーとして「アダムオプティマイザー」を選択
    const optimizer = tf.train.adam(0.01);

    // 損失関数と精度のメトリックス(カテゴリカル交差エントロピー)を指定し、モデルをコンパイル
    model.compile({
        optimizer,
        loss: 'categoricalCrossentropy',
        metrics: ['accuracy']
    });

    // 上記で作成したモデルの要約情報(Layer Name, Output Shape, # Of Params, Trainable)表示
    // 注)この動作はGoogleが公開したオリジナルコードにはありません
    tfvis.show.modelSummary({name: 'Model Summary'}, model);

}

//******************************************************************
// 2. collect: ブラウザから追加の学習データとなる音声を取得する
//******************************************************************

// 録音して学習させる音声のテンソルを格納する配列としてexampleを用意。
let examples = [];

// ブラウザ画面にある「left」「right」「noise」ボタンが押された際にマイクの音を拾って学習データを生成する関数
// ラベル付けは「left」が「0」、「right」が「1」、「noise」が「2」とする。
function collect(label) {

    // もし音声取得が既に行われている状況だったら、音声取得を停止する。
    if (recognizer.isListening()) {

        return recognizer.stopListening();

    }

    // onMouseUpで音声取得を終了する。
    if (label == null) {

        return recognizer.stopListening();

    }

    // onMouseDownの際にマイクが拾った音声をデータ化してラベルを付けてオブジェクトとして学習データexample配列に格納する
    recognizer.listen(async ({spectrogram: {frameSize, data}}) => {

        // 入力された音声を正規化
        let vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
        // 配列に追加して格納
        examples.push({vals, label});
        // 画面には拾った音声データの内容を表示
        document.querySelector('#console').textContent = `${examples.length} examples collected`;

    }, {

        // recognizer.listenの挙動を制御するパラーメータの様子(詳細不明)
        overlapFactor: 0.999,
        includeSpectrogram: true,
        invokeCallbackOnNoiseAndUnknown: true

    });

}

// exampleという配列に上記のcollectで追加された学習データを含めてモデルに学習させる関数。
// ブラウザの「train」ボタンをトリガーにして動作する。
async function train() {

    toggleButtons(false);
    const ys = tf.oneHot(examples.map(e => e.label), 3);
    const xsShape = [examples.length, ...INPUT_SHAPE];
    const xs = tf.tensor(flatten(examples.map(e => e.vals)), xsShape);

    await model.fit(xs, ys, {
    batchSize: 16,
    epochs: 10,
    callbacks: {
        onEpochEnd: (epoch, logs) => {
            document.querySelector('#console').textContent =
           `Accuracy: ${(logs.acc * 100).toFixed(1)}% Epoch: ${epoch + 1}`;
        }
    }
    });

    tf.dispose([xs, ys]);
    toggleButtons(true);

}

// マイクから音を拾って、その入力音声が何であるかを既成モデルを使って判別する関数。
// ブラウザの「listen」ボタンをトリガーにして動作する。
function listen() {

    // 既に音声を取得中の状態の場合
    if (recognizer.isListening()) {

        // 音声取得を停止する
        recognizer.stopListening();
        // ブラウザ画面の「Stop」というボタン表記を「Listen」に変える
        toggleButtons(true);
        document.getElementById('listen').textContent = 'Listen';
        return;

    }

    toggleButtons(false);

    document.getElementById('listen').textContent = 'Stop';
    document.getElementById('listen').disabled = false;

    recognizer.listen(async ({spectrogram: {frameSize, data}}) => {

        const vals = normalize(data.subarray(-frameSize * NUM_FRAMES));
        const input = tf.tensor(vals, [1, ...INPUT_SHAPE]);
        const probs = model.predict(input);
        const predLabel = probs.argMax(1);
        await moveSlider(predLabel);
        tf.dispose([input, probs, predLabel]);

    }, {

        overlapFactor: 0.999,
        includeSpectrogram: true,
        invokeCallbackOnNoiseAndUnknown: true

    });

}

// 正規化のための関数
function normalize(x) {

    const mean = -100;
    const std = 10;
    return x.map(x => (x - mean) / std);

}

// ブラウザに表示するtrainとlistenのボタンをon/offするための表面的な機能向け関数
function toggleButtons(enable) {

    document.querySelectorAll('button').forEach(b => b.disabled = !enable);

}

// 多次元テンソルを1次元にreshapeする関数
function flatten(tensors) {

    const size = tensors[0].length;
    const result = new Float32Array(tensors.length * size);
    tensors.forEach((arr, i) => result.set(arr, i * size));
    return result;

}

// 入力される音声がどのラベルに対応するか(leftなら0、rightなら1、ノイズなら2)を0から9の幅で
// リアルタイムに表示するためスライドバーを逐次に動かす関数
async function moveSlider(labelTensor) {

    const label = (

        await labelTensor.data()

    )[0];
    document.getElementById('console').textContent = label;

    if (label == 2) {

        return;

    }

    let delta = 0.1;
    const prevValue = +document.getElementById('output').value;
    document.getElementById('output').value =
     prevValue + (label === 0 ? -delta : delta);

}

app();


追伸: Machine Learning Tokyoと言うMachine Learningの日本最大のグループに参加しています。作業系の少人数会合を中心に顔を出しています。基本的に英語でのコミュニケーションとなっていますが、能力的にも人間的にもトップレベルの素晴らしい方々が参加されておられるので、機会がありましたら参加されることをオススメします。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1