LoginSignup
8
8

More than 5 years have passed since last update.

TensorFlow.jsを使って、Webで画像による本人認証機能を実装した話。

Last updated at Posted at 2018-08-22

概要

2018年3月に発表され、TensorFlow.js!!
TensorFlowとjavascriptが合わさったということで初めて知った時にはかなりのパワーワードでした 笑
面白がってデモを触ってコードをみていたところ、「あれ、これいけるんじゃね?」と思い立ち作ってみました!

作ったもの

FasePass
ブラウザでの画像認証を実装したWebアプリケーション
画像のアップと認証ができる。

使用技術

インフラ周り: heroku
サーバー: Express
フロント: React
データベース: MongoDB

Webカメラの取りこみおよび撮影は偉大なGoogle様から拝借いたしました。
https://js.tensorflow.org/tutorials/webcam-transfer-learning.html

フローチャート

画像をDBにアップロード

SignUp画面でパソコンの付属カメラから画像を撮影してDBに保存する
クライエント側でPC付属のカメラから画像を撮影してDBに保存する処理を行なっている。学習の前処理もかねており、画像そのままでは、学習が遅れるためGoogleのモデルを用いて数千次元のベクトルに変換している。

server.js
require('@tensorflow/tfjs')
const tf = require('@tensorflow/tfjs-node');
const mongoose = require('mongoose');
const User = mongoose.model('User');
module.exports = app => {
  app.post('/api/add_face_data', async (req, res) => {

    // 画像データがDBに300枚以上ある場合には削除する。
    const AllImage = User.find({})
    const number = AllImage.length
    if (number >= 300) {
      const image_id = await AllImage.sort({created_at: -1}).limit(1)[0]._id
      User.deleteOne({_id: image_id})
    }

    const newUser = await User.create({email: req.body.email, x_data: req.body.x, created_at: Date.now()})
    if (newUser) {
      const AllImage = await User.find({email: req.body.email})
      const number = AllImage.length
      return res.json({message: "Goob job!", newUser: newUser, dataAmount: number})
    }
front.js
    async componentDidMount () {
      await webcam.setup();
      mobilenet = await this.loadMobilenet();
    }

    async loadMobilenet () {
      // Googleのモデルを取得
      const mobilenet = await tf.loadModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');
      const layer = mobilenet.getLayer('conv_pw_13_relu');
      return tf.model({inputs: mobilenet.inputs, outputs: layer.output});
    }

    // 撮影時に発火する関数
    async handleOnClick () {
      this.setState({processing: true})
      const {name} = this.state

      // カメラからデータを取得し、Googleのモデルで変換
      const xs = mobilenet.predict(webcam.capture())
      const x_data = await xs.data()
      let params = {
        email: name,
        x: x_data.toString()
      }
      const res = await axios.post('/api/add_face_data', params)
      if (res) {
        this.setState({processing: false, photoNum: res.data.dataAmount})
      }
    }

正直画像のDBへの保存が一番苦労した。S3などで画像をそのまま保存しても良かったが、学習速度の観点からデータ化することにした。当初30枚程度一度に送ろうとしていたが、データ量が多すぎ断念。1枚でも多すぎるというお言葉をサーバーからいただいた。データ型をTensol => Array => Stringにすることでなんとか解決。しんどかった...。

画像をDBから取り学習する

入力されたラベルを用いてDBから画像のリストを取得する。
ラベルと紐付いていないデータも同時に取得し、Fakeデータとして用いる。すでにベクトルに変換されているため、実際には出力層だけな簡単なモデルで学習している。

server.js
  app.post('/api/<hide>', async (req, res) => {
    if (!req.body.email) {
      return res.json({message: "email was not provided"})
    }
    const AllImage = await User.find({email: req.body.email})

    const fakeImages = await User.find({email: {'$ne': req.body.email }});
    return res.json({images: AllImage, fake_images: fakeImages})
  })
画像の取得
fetchImages () {
    return new Promise(async (solve, reject) => {
      const { email } = this.state
      const params = { email }
      const res = await axios.post('/api/<hide>', params)
      const image_string = res.data.images.map((item) => item.x_data)
      const fake_image_string = res.data.fake_images.map((item) => item.x_data)

      image_string.forEach((value) => {
        const image_tensor = tf.tensor1d(value.split(','))
        controllerDataset.addExample(image_tensor.reshape([1, 7, 7, 256]), 1)
      })

      fake_image_string.forEach((value) => {
        const image_tensor = tf.tensor1d(value.split(','))
        controllerDataset.addExample(image_tensor.reshape([1, 7, 7, 256]), Math.floor(Math.random() * 9 ) + 2)
      })
      solve()
    })    
  }
モデルの学習とコンパイル
      model = tf.sequential({
        layers: [
          tf.layers.flatten({inputShape: [7, 7, 256]}),
          tf.layers.dense({
            units: 100,
            activation: 'relu',
            kernelInitializer: 'varianceScaling',
            useBias: true
          }),
          tf.layers.dense({
            units: 10,
            kernelInitializer: 'varianceScaling',
            useBias: false,
            activation: 'softmax'
          })
        ]
      })

      const optimizer = tf.train.adam(0.0001);
      model.compile({optimizer: optimizer, loss: 'categoricalCrossentropy'});

 出力層だけの簡単なモデル。Google様は偉大です。

学習してモデルの作成
     const batchSize = Math.floor(controllerDataset.xs.shape[0] * 0.4)
      if (!(batchSize > 0)) throw new Error(`Batch size is 0 or NaN. Please choose a non-zero fraction.`)

      model.fit(controllerDataset.xs, controllerDataset.ys, {
        batchSize,
        epochs: 20,
        callbacks: { onBatchEnd: async (batch, logs) => await tf.nextFrame() }
      })

1を正解ラベルとして学習。そのほかのfakeデータは2~9の間でランダムにラベル付け。ちなみに学習の際にはTraining...と表示するようにしているのだが、目では確認できないほどに高速で学習が終わる。

モデルを用いてログインできるかどうかの判定

LogIn画面で画像を撮影し、判定する。
再度Googleのモデルを用いて500次元のベクトルに変換し、先ほど作成したモデルを用いると1~10の間でラベルが出る。このうち正解ラベルが選ばれた場合にのみログインが成功したとみなす。

front.js

  loadMobilenet () {
    return new Promise(async (solve, reject) => {
      const localMobilenet = await tf.loadModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json')
      const layer = localMobilenet.getLayer('conv_pw_13_relu');
      solve(tf.model({inputs: localMobilenet.inputs, outputs: layer.output}))
    })
  }

  async predicting () {
    while (true) {
      const classId = await this.predict()
      if (classId === 1) {
        resultMessage.innerHTML = "Successed"
      } else {
        resultMessage.innerHTML = "Failed"

      }
    }
  }

  predict () {
    return new Promise((solve, reject) => {
      setTimeout(async () => {
        const predictedClass = tf.tidy(() => {
          const img = webcam.capture();
          const activation = mobilenet.predict(img);
          const predictions = model.predict(activation);
          return predictions.as1D().argMax();
        });
        const classId = (await predictedClass.data())[0];
        solve(classId)
      }, 1000);
    })
  }

1秒ごとに画像を読み込み、モデルで判定させている。
やったぜ。

感想

いやはや画像の保存が大変大変。なんかもっと上手く行く方法とかありそう。教えて偉い人。

8
8
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
8