3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlow.jsでDeepLearning(Handwritten digit recognition with CNNs)

Last updated at Posted at 2019-05-14

こんにちわ。Electric Blue Industries Ltd.という、ITで美を追求するファンキーでマニアックなIT企業のマッツと申します。TensorFlow.jsでDeepLearningのチュートリアル「Handwritten digit recognition with CNNs」の詳細解説です。

これはTensorFlow.JSの公式サイトにある「TensorFlow.js — Handwritten digit recognition with CNNs」をコードの中に記載したコメントで詳細に解説したものです。解説の利便性によりコードの部分の位置関係は変更してありますが、内容に変化はありません。実際に動作するデモはこちらで見られます。

#1. コード
##1.1. html
ライブラリを読み込んでDeep LearningのためのJavaScriptを実行するHTMLです。

展開してコードを見る
index.html
<html>
<head>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>TensorFlow.js Tutorial</title>

    <!-- Import TensorFlow.js -->
    <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.js"></script>
    <!-- Import tfjs-vis -->
    <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js"></script>
    <!-- Import the main script file -->
    <script src="script.js" type="module"></script>

</head>
<body>
</body>
</html>

##1.2. JavaScript
上記のHTMLにインクルードされてDeep Learning処理を行うJavaScriptです。リリースされたオリジナルのコードでは、学習データ(画像)を取得してバイナリー処理をしてモデルの入出力となるテンソルを生成するコードは別にあってインクルードするだけですが、その処理に画像データ処理の本質があると考えて、下記のコードに展開して解説をしてあります。

展開してコードを見る
script.js
/******************************************************************
TensorFlow.js — Handwritten digit recognition with CNNs
url: https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#0
filename: script.js
copyrighted to: tensorflow.org
composed by: Mats (Electric Blue Industries Ltd.)
description: 手書きの数字(0から9)を画像として読んでどの数字か判別する
******************************************************************/

// 1画像を784(=28x28)ピクセルで表現する
const IMAGE_SIZE = 784;

// 0から9の数字を判別するので、合計10個の分類に呼応する10個のラベル分類を持つ
const NUM_CLASSES = 10;

// 読みこむ元データ画像1つには合計65,000枚の画像が含まれます。
// モデルのトレーニングには最大55,000枚の画像を使用し、モデルのパフォーマンスをテストするために使用できる10,000枚の画像を保存します。

// 読み込みの利便性を考慮して作られた少し特殊な元データで、65,000個以上のの手書き数字の画像データを1つの画像にまとめたもので、
// 学習データの画像ピクセルを一列に並べて別の大きな1画像となったもの(実サイズが784px × 65000px)。
// 今回のケースでは横1行が学習データ1枚の全画素(28x28=784)となっている。
const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
// 上記学習データのラベルとして符号なし8ビット整数(0と正の数)
const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

// 上記の元データ画像1つから生成する数字画像データ数を65,000個とする
const NUM_DATASET_ELEMENTS = 65000;
// 元データ画像1つから生成する数字画像データ数のうち 5/6 をテスト用とする
const TRAIN_TEST_RATIO = 5 / 6;
// 学習用に用いる画像の数
const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS); // floor(5/6*65000) = 54166
// 学習後のテストに使う画像の数
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;    // 65000-54166 = 10834

//******************************************************************
// 分割されたMNISTデータセットを取得してシャッフルされたバッチを返すクラス
// NOTE: これはずっと簡単になります。今のところ、データの取得と操作は手動で行います。
//******************************************************************

class MnistData {

    async load() {

        // MNISTのスプリットイメージ(上記で定義した28x28ピクセルを画素とした時に各画素に呼応する色彩の値)を要求します。
        
        const img = new Image();
        // まず、画面には描画しないが途中処理用としてcanvasを用意する
        const canvas = document.createElement('canvas');
        // キャンバスに描画する画像は二次元データとする
        const ctx = canvas.getContext('2d');
        
        const imgRequest = new Promise((resolve, reject) => {
            
            // 学習データの元画像を取得
            img.src = MNIST_IMAGES_SPRITE_PATH;

            // CORS (Cross-Origin Resource Sharing / クロスドメイン通信) 設定属性です。
            // 別オリジンから読み込んだ画像などのリソースを文書内で利用する際のルールを指定します。
            img.crossOrigin = '';

            // 学習データの元画像を取得できた場合に下記を実行
            img.onload = () => {

                // 画像の表示幅ではなく、画像のデータ上の本当の幅(784px)を画像の特性値とする。
                img.width = img.naturalWidth;
                // 画像の表示高ではなく、画像のデータ上の本当の高さ(65000px)を画像の特性値とする。
                img.height = img.naturalHeight;

                // ArrayBuffer(n)は8bit(=1byte)がn個入るバッファ領域が用意する。
                // 後に1画素の情報を32bit(=4byte)で扱うので、4倍して65000枚 x 28画素 x 28画素 x 4 byteの容量のバッファを用意。
                // 具体的には203,840,000byte(だいたい200MegaByte)のバッファ領域が確保される。
                // なお、バッファを確保した時点では全ての要素は0が入っている状態。
                const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
                
                // this.datasetImagesという配列には50,960,000個の要素があり、各要素が65,000枚の画像の全ての画素28x28個分セットに
                // 呼応して1次元配列で保存されている。各要素は4byte(=32bit)の情報で構成されており、「50,960,000個 x 4byte = 203,840,000byte」が
                // この配列のデータサイズ(確保したバッファ領域と等しい)。
                this.datasetImages = new Float32Array(datasetBytesBuffer);

                // 1枚の元画像データを手書き数字の画像データ5000枚分ごとに切り出すと指定
                const chunkSize = 5000;
                canvas.width = img.width;
                canvas.height = chunkSize;

                // 65000枚分の元データ画像を5000枚分ごとに切り出していくループなので 65000/5000 = 13回ループすることになる
                for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {

                    // 要素に32ビット浮動小数を要素として格納する配列を用意
                    // Float32Array(buffer, byteOffset, length)はbufferで確保したバッファ領域に対し、前からbyteOffset番目の場所から
                    // lengthの幅をデータ保存領域として確保することを意味する。
                    const datasetBytesView = new Float32Array(datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4, IMAGE_SIZE * chunkSize);

                    // ctx.drawImage(image, sx, sy, sw, sh, dx, dy, dw, dh)は
                    // imgで示す画像の左上から(sx,sy)の点を起点として幅sw高さshを切り取って、
                    // canvasの左上から(dx,dy)の点を起点として幅dw高さdhの領域に描画することを意味する。
                    // この場合は、元画像データの上から数えて、当該のチャンクに相当する部分の画像部分をcanvasの左上を起点として貼る行為。
                    ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);

                    // canvasに描画した1チャンク分(=画像5000枚分)のイメージデータを丸ごと取得しimageDataに格納する。
                    // なお、canvasのイメージデータは1画素を「赤・緑・青・透明度」の4種類の情報を各1byteで表現した4バイトデータである。
                    // さらにいえば、1byteは2の8乗であり、0から255までの数字で表現される。
                    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

                    // 取得したイメージデータ、1画素を4byteで表現するので、画素の数を数えるには配列要素数を4で割ることになる
                    for (let j = 0; j < imageData.data.length / 4; j++) {

                        // 取り込んだ画像データの画素はモノクロなので「赤・緑・青」の3チャンネルは同じ数字であるので、赤の情報だけ4つ飛ばしで読み込む。
                        // また、読んだ値は0から255までの数値なので、0から1に収めるために255で割っておく。
                        // こうして、dataBytesViewと言う配列には65000枚の全画像の全画素の赤の濃さが0から1までの数値で収まった状態。
                        datasetBytesView[j] = imageData.data[j * 4] / 255;

                    }

                }

                resolve();

            };

        });

        // fetchを使って非同期で学習データのラベルを取得
        // このデータは0から650000の数値データ(8bit表現の整数)が入っており、これを650000枚の各画像のラベルに用いる(順番は画像データの順番に呼応している)
        const labelsRequest = fetch(MNIST_LABELS_PATH);
        
        // 上記の画像取得&処理およびラベル取得処理が正常に完了したらそのレスポンスを格納しておく
        const [imgResponse, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);
        
        // labelsResponse.arrayBuffer() はlabelsResponseで得た値(650000個の8bit数値)が入っているので、これを1要素8bitの配列に格納する
        // そしてこの650000要素からなるデータラベルの入った配列をdatasetLabelsとする。
        this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

        // tf.util.createShuffledIndices(n)は0から(n-1)までの整数をシャッフルして並べた32bit配列を返す関数。
        this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
        this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

        // Array.slice(x, y)は配列のインデックス番号(0で始まる)で数えてx番目からy番目までを切り抜いた配列を返す関数。
        // datasetImagesに収めた先頭から54166枚目分までを切り出して学習用画像データとして配列に保存
        this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
        // datasetImagesに収めた54167枚目分から650000枚目分までを切り出してテスト用画像データとして配列に保存
        this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
        
        // 上記と同様に学習用とテスト用の画像のラベルも別配列として切り出す
        this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
        this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);

    }
    
    // コンストラクタメソッドは、クラス内で作成されたオブジェクトを作成および初期化するための特別なメソッドです。
    // MnistDataというクラスの中でオブジェクトを作成し初期化する。
    constructor() {

        this.shuffledTrainIndex = 0;
        this.shuffledTestIndex = 0;

    }
    
    // トレーニングセットから画像とそのラベルのランダムなバッチを返します。
    nextTrainBatch(batchSize) {

        // 下記で定義したnextBatch関数
        return this.nextBatch(batchSize, [this.trainImages, this.trainLabels], () => {

            this.shuffledTrainIndex = (this.shuffledTrainIndex + 1) % this.trainIndices.length;
            return this.trainIndices[this.shuffledTrainIndex];

        });

    }

    // テストセットから画像とそのラベルのランダムなバッチを返します。
    nextTestBatch(batchSize) {

        // 下記で定義したnextBatch関数
        return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {

            this.shuffledTestIndex = (this.shuffledTestIndex + 1) % this.testIndices.length;
            return this.testIndices[this.shuffledTestIndex];

        });

    }
    
    nextBatch(batchSize, data, index) {

        // まずバッチサイズに含まれる画像データ(28x28画像)と容量が等しい32bit要素配列を器として作成
        const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
        // 次に、各画像が1から9のどの数字の確率が高いかの評価の器としてバッチサイズに含まれる画像データ数 x 10個分の要素を持つ8bit要素配列を作成
        const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

        // 画像およびラベルのデータの塊から
        for (let i = 0; i < batchSize; i++) {

            const idx = index();

            const image = data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
            // Array.set(x, y)でxをArrayのy番目要素の位置に差し込む。
            // batchImagesArrayの「i * IMAGE_SIZE」番目要素として上記で切り出した画像の784画素のデータを流し込む。
            // この時点でbatchImagesArrayという配列は5000x784個の要素を格納する1次元配列を形成していく。
            batchImagesArray.set(image, i * IMAGE_SIZE);

            // 上記と同様にラベルについて1次元配列を形成していく。
            const label = data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
            batchLabelsArray.set(label, i * NUM_CLASSES);

        }

        // tf.tensor2d("一次元配列", "2次元テンソルのshape")
        // 下記により画像の画素情報が格納された5000行784列の二次元テンソルが生成される。
        const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
        // 下記により画像のラベル情報が格納された5000行10列の二次元テンソルが生成される。
        const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

        // 上記で得られた2つのテンソルを格納したオブジェクトを返す。
        return {xs, labels};

    }

}

async function run() {
    
    // 上記で定義したクラス「MnistData」を取り込み。
    // 結果としてMnistData内でload()関数が実行されてArrayBufferに学習データおよびテストデータの生データがdataとして格納されている状態。
    // 学習およびテスト用の{xs, labels}形式のオブジェクトはまだ生成されていない。
    const data = new MnistData();
    await data.load();
    
    //******************************************************************
    // 1. showExamples: 学習データの例を画面に表示する
    //******************************************************************
    
    await showExamples(data);
    
    async function showExamples(data) {

        // TF-VISで学習データ画像の例を表示する画面の枠組みを決める。
        const surface = tfvis.visor().surface({
            name: 'Input Data Examples',
            tab: 'Input Data'
        });  

        // 画像65000枚分のデータから20個分を切り出して、{[20, 784], [20, 10]}という2種類のshapeのテンソルが格納されたオブジェクトを得る。
        const examples = data.nextTestBatch(20);
        // examples.xs.shapeという配列は[20, 784]という「20行784列」を意味する値を持っているのでexamples.xs.shape[0]の値は「20」となる。
        // ちなみにexamples.labels.shapeは[20, 10]という「20行10列」を意味する値を持っているのでexamples.labels.shape[0]の値も「20」となる。
        const numExamples = examples.xs.shape[0]; // 20

        // Create a canvas element to render each example
        for (let i = 0; i < numExamples; i++) {

            const imageTensor = tf.tidy(() => {

                // examples.xs.shape[1]は784
                // i個目の画像の最初の画素から784個目の画素を取得して[28, 28, 1]の形のテンソルに再形成する。
                return examples.xs.slice([i, 0], [1, examples.xs.shape[1]]).reshape([28, 28, 1]);

            });

            // canvasに立て続けに投げ込んで描画
            const canvas = document.createElement('canvas');
            canvas.width = 28;
            canvas.height = 28;
            canvas.style = 'margin: 4px;';
            await tf.browser.toPixels(imageTensor, canvas);
            surface.drawArea.appendChild(canvas);

            imageTensor.dispose();

        }

    }

    //******************************************************************
    // 2. getModel: モデルの枠組みの作成
    //******************************************************************

    const model = getModel();

    function getModel() {

        // シーケンシャルモデルの枠組みを作成
        const model = tf.sequential();

        // 上記のクラス内で取得し作った学習データに合わせて入力テンソルとなる画像データのshapeを定義
        const IMAGE_WIDTH = 28;
        const IMAGE_HEIGHT = 28;
        const IMAGE_CHANNELS = 1;  

        // tf.layers.conv2d()
        //
        // アウトコンボリューションニューラルネットワークの最初の層では、入力形状を指定する必要があります。
        // 次に、このレイヤで行われる畳み込み演算のためのいくつかのパラメータを指定します。
        //
        // inputShape:  定義されている場合、このレイヤの前に挿入する入力レイヤを作成するために使用されます。
        //              inputShapeとbatchInputShapeの両方が定義されている場合は、batchInputShapeが使用されます。
        //              この引数は入力レイヤ(モデルの最初のレイヤ)にのみ適用できます。
        // kernelSize:  畳み込みウィンドウの大きさ。
        //              kernelSizeが数値の場合、たたみ込みウィンドウは正方形になります。
        // filters:     出力空間の次元数(すなわち、畳み込みにおけるフィルタの数)。
        // strides:     各次元における畳み込みの進歩。
        //              ストライドが数値の場合、両方の次元のストライドは等しくなります。
        // activation:  レイヤーの活性化関数の種類。
        // kernelInitializer:   畳み込みカーネル重み行列の初期化子。

        model.add(tf.layers.conv2d({
            inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
            kernelSize: 5,
            filters: 8,
            strides: 1,
            activation: 'relu',
            kernelInitializer: 'varianceScaling'
        }));

        // tf.layers.maxPooling2d()
        //
        // MaxPoolingレイヤーは、平均化の代わりに領域内の最大値を使用する一種のダウンサンプリングとして機能します。
        //
        // poolSize:    各次元で縮小する倍率[垂直、水平]。
        //              整数または2つの整数の配列が必要です。
        // strides:     プールウィンドウの各次元におけるストライドのサイズ。
        //              整数または2つの整数の配列が必要です。
        //              整数、2整数のタプル、またはNone。

        model.add(tf.layers.maxPooling2d({
            poolSize: [2, 2],
            strides: [2, 2]
        }));

        // 別のconv2d + maxPoolingスタックを繰り返します。コンボリューションにはもっとフィルタがあることに注意してください。

        model.add(tf.layers.conv2d({
            kernelSize: 5,
            filters: 16,
            strides: 1,
            activation: 'relu',
            kernelInitializer: 'varianceScaling'
        }));

        model.add(tf.layers.maxPooling2d({
            poolSize: [2, 2],
            strides: [2, 2]
        }));

        // これで、2Dフィルタからの出力を1Dベクトルにフラット化して、最後のレイヤーに入力できるようにしました。
        // これは、高次元のデータを最終的な分類出力レイヤに送るときに一般的な方法です。

        model.add(tf.layers.flatten());

        // 10個の出力 (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9).
        const NUM_OUTPUT_CLASSES = 10;

        // tf.layers.dense()
        //
        // units:   正の整数、出力スペースの次元数
        // kernelInitializer:   密カーネル加重行列の初期化子
        // activation:  使用する活性化関数

        model.add(tf.layers.dense({
            units: NUM_OUTPUT_CLASSES,
            kernelInitializer: 'varianceScaling',
            activation: 'softmax'
        }));


        // オプティマイザーとして「アダムオプティマイザー」を選択
        const optimizer = tf.train.adam();
        
        // 損失関数と精度のメトリックス(カテゴリカル交差エントロピー)を指定し、モデルをコンパイル
        model.compile({
            optimizer: optimizer,
            loss: 'categoricalCrossentropy',
            metrics: ['accuracy'],
        });

        return model;

    }
    
    // 作成したモデルの概要をTFVISで描画
    tfvis.show.modelSummary({
        name: 'Model Architecture'
    }, model);
    
    //******************************************************************
    // 3. train: 作成したモデルの学習
    //******************************************************************

    await train(model, data);
    
    async function train(model, data) {

        // TFVISでの描画用に指標とグラフ枠を定義
        const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
        const container = {
            name: 'Model Training',
            styles: {
                height: '1000px'
            }

        };

        const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

        // 学習データの作成時にはバッチサイズを5000にして作成したが、学習時は512(2の累乗にするのが慣習)で行う。
        const BATCH_SIZE = 512;
        const TRAIN_DATA_SIZE = 5500;
        const TEST_DATA_SIZE = 1000;

        const [trainXs, trainYs] = tf.tidy(() => {

            // 上記で定義したnextTrainBatchで学習用入力テンソルと出力ラベルを5500個生成
            const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
            return [
                // d.xsは5500行784列の行列なので、これをモデルの入力に合うように5500個の[28, 28, 1]のテンソル([5500, 28, 28, 1])にreshapeする。
                d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
                // 出力は元から[5500, 10]の状態なのでreshape不要
                d.labels
            ];

        });
        
        console.log([trainXs, trainYs]);

        const [testXs, testYs] = tf.tidy(() => {

            // 上記で定義したnextTestBatchでテスト用入力テンソルと出力ラベルを1000個生成
            const d = data.nextTestBatch(TEST_DATA_SIZE);
            return [
              d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
              d.labels
            ];

        });

        // model.fit(x, y, args)
        //
        // batchSize:       勾配更新ごとのサンプル数。
        //                  指定しない場合は、デフォルトの32になります。
        // validationData:  各エポックの終わりに損失とモデルメトリックを評価するためのデータ。
        //                  モデルはこのデータについてトレーニングされません。
        //                  これはタプル[xVal、yVal]またはタプル[xVal、yVal、valSampleWeights]です。
        //                  モデルはこのデータについてトレーニングされません。
        //                  validationDataはvalidationSplitをオーバーライドします。
        // epochs:          トレーニングデータ配列を反復する回数。
        // shuffle:         各エポックの前にトレーニングデータをシャッフルするかどうか。
        //                  stepsPerEpochがnullでない場合は無効です。
        // callbacks:       トレーニング中に呼び出されるコールバックのリスト。
        //                  onTrainBegin、onTrainEnd、onEpochBegin、onEpochEnd、onBatchBegin、onBatchEndの1つ以上のフィールドで構成できます。
        
        return model.fit(trainXs, trainYs, {
            batchSize: BATCH_SIZE,
            validationData: [testXs, testYs],
            epochs: 10,
            shuffle: true,
            callbacks: fitCallbacks
        });
    }
    
    //******************************************************************
    // 4. showAccuracy & showConfusion
    //******************************************************************

    const classNames = ['Zero', 'One', 'Two', 'Three', 'Four', 'Five', 'Six', 'Seven', 'Eight', 'Nine'];

    // 4.1. showAccuracy: 学習したモデルに500個の画像を読ませて、その実際の精度を表にして表示する
    
    await showAccuracy(model, data);

    // showAccuracy: TF-VISを用いて学習させたモデルの予測精度を描画する関数
    async function showAccuracy(model, data) {

        const [preds, labels] = doPrediction(model, data);

        const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
        const container = {
            name: 'Accuracy',
            tab: 'Evaluation'
        };
        tfvis.show.perClassAccuracy(container, classAccuracy, classNames);

        labels.dispose();

    }

    // 4.2. showConfusion: 学習したモデルに再度500個の画像を読ませて、実際のラベルと出力値の関係(コンフージョン・マトリックス)を表示する

    await showConfusion(model, data);
    
    // showConfusion: TF-VISを用いてコンフージョンマトリックスを描画する関数
    async function showConfusion(model, data) {

        const [preds, labels] = doPrediction(model, data);
        const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
        const container = {
            name: 'Confusion Matrix',
            tab: 'Evaluation'
        };
        tfvis.render.confusionMatrix(
          container, {values: confusionMatrix}, classNames);

        labels.dispose();

    }
    
    // doPrediction: 学習したモデルに指定した個数のテストデータを読ませて予測出力を返す関数
    function doPrediction(model, data, testDataSize = 500) {

        const IMAGE_WIDTH = 28;
        const IMAGE_HEIGHT = 28;
        const testData = data.nextTestBatch(testDataSize);
        const testxs = testData.xs.reshape([testDataSize, IMAGE_WIDTH, IMAGE_HEIGHT, 1]);
        const labels = testData.labels.argMax([-1]);
        const preds = model.predict(testxs).argMax([-1]);

        testxs.dispose();
        return [preds, labels];

    }
    
}

document.addEventListener('DOMContentLoaded', run);

#2. 実行結果
##2.1. 入力データ

2019-05-14_16-10-00.png

##2.2. モデル
###2.2.1. 要約

2019-05-14_16-10-59.png

###2.2.2. 学習

2019-05-14_16-11-39.png

2019-05-14_16-12-03.png

##2.3. 予測精度

2019-05-14_16-12-52.png

2019-05-14_16-13-19.png

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?