LoginSignup
4
3

More than 3 years have passed since last update.

React + TensorFlow.jsで手書き数字認識アプリを作ってみた

Last updated at Posted at 2019-12-30

はじめに

サークルの講習会で、kivyとTensorFlow?を使って手書きの数字を認識するアプリを作っている先輩がいたので、これをReactでやってみたいなと思ったので、今回やってみました。このアプリを作る中で、こちらのサイトを非常に参考にさせていただきました。また、僕は機械学習についてあまり知識がないので、今回はpythonを全く書きません。機会があれば、モデルの作成からやってみたいと思います。

完成物

tegaki2.gif

ソースコード:https://github.com/Alesion30/predict-number-app

インストールしたパッケージ

TensorFlow.js: https://www.tensorflow.org/js/
react-signature-canvas: https://www.npmjs.com/package/react-signature-canvas
material-ui: https://material-ui.com/

"@material-ui/core": "^4.8.1",
"@tensorflow/tfjs": "^0.12.6",
"react-signature-canvas": "^1.0.3",

※ tensorflow/tfjsは上記のバージョンに合わせてください。

ファイル構造

create-react-appコマンドで作成したプロジェクトをベースに話を進めます。以下のファイル構造は、編集または新規作成したファイルのみ書いています。

src-
   |-components
   |      |-Accuracy.js
   |      |-AccuracyTable.js
   |
   |-App.js
   |-App.css
components/Accuracy.js
import React from "react";

const Accuracy = props => {
  const { no, content } = props;

  return (
    <tr>
      <th>{no}</th>
      <td className="accuracy" data-row-index={`${no}`}>
        {content}
      </td>
    </tr>
  );
};

export default Accuracy;
components/AccuracyTable.js
import React from "react";
import Accuracy from "./Accuracy";

const AccuracyTable = () => (
  <table className="table">
    <thead>
      <tr>
        <th>数字</th>
        <th>精度</th>
      </tr>
    </thead>
    <tbody>
      <Accuracy no={0} content="-" />
      <Accuracy no={1} content="-" />
      <Accuracy no={2} content="-" />
      <Accuracy no={3} content="-" />
      <Accuracy no={4} content="-" />
      <Accuracy no={5} content="-" />
      <Accuracy no={6} content="-" />
      <Accuracy no={7} content="-" />
      <Accuracy no={8} content="-" />
      <Accuracy no={9} content="-" />
    </tbody>
  </table>
);

export default AccuracyTable;
App.js
import React from "react";
import "./App.css";
import * as tf from "@tensorflow/tfjs";
import SignatureCanvas from "react-signature-canvas";
import { Button } from "@material-ui/core";
import AccuracyTable from "./components/AccuracyTable";

class App extends React.Component {
  constructor() {
    super();
    this.state = {
      is_loading: "is-loading",
      model: null,
      maxNumber: null,
      maxScore: null
    };
    this.onRef = this.onRef.bind(this);
    this.getImageData = this.getImageData.bind(this);
    this.getAccuracyScores = this.getAccuracyScores.bind(this);
    this.predict = this.predict.bind(this);
    this.reset = this.reset.bind(this);
  }

  componentDidMount() {
    tf.loadModel(
      "https://raw.githubusercontent.com/tsu-nera/tfjs-mnist-study/master/model/model.json"
    ).then(model => {
      this.setState({
        is_loading: "",
        model
      });
    });
  }

  onRef(ref) {
    this.signaturePad = ref;
  }

  getAccuracyScores(imageData) {
    const scores = tf.tidy(() => {
      const channels = 1;
      let input = tf.fromPixels(imageData, channels);
      input = tf.cast(input, "float32").div(tf.scalar(255));
      input = input.expandDims();
      return this.state.model.predict(input).dataSync();
    });
    return scores;
  }

  getImageData() {
    return new Promise(resolve => {
      const context = document.createElement("canvas").getContext("2d");
      const image = new Image();
      const width = 28;
      const height = 28;

      image.onload = () => {
        context.drawImage(image, 0, 0, width, height);
        const imageData = context.getImageData(0, 0, width, height);

        for (let i = 0; i < imageData.data.length; i += 4) {
          const avg =
            (imageData.data[i] +
              imageData.data[i + 1] +
              imageData.data[i + 2]) /
            3;
          imageData.data[i] = avg;
          imageData.data[i + 1] = avg;
          imageData.data[i + 2] = avg;
        }
        resolve(imageData);
      };

      image.src = this.signaturePad.toDataURL();
    });
  }

  predict() {
    this.getImageData()
      .then(imageData => this.getAccuracyScores(imageData))
      .then(accuracyScores => {
        const maxAccuracy = accuracyScores.indexOf(
          Math.max.apply(null, accuracyScores)
        );
        const elements = document.querySelectorAll(".accuracy");
        elements.forEach(el => {
          el.parentNode.classList.remove("is-selected");
          const rowIndex = Number(el.dataset.rowIndex);
          if (maxAccuracy === rowIndex) {
            el.parentNode.classList.add("is-selected");
          }
          el.innerText = Math.round(accuracyScores[rowIndex] * 1000) / 1000;
        });
        this.setState({
          maxNumber: maxAccuracy,
          maxScore: accuracyScores[maxAccuracy]
        });
        console.log(accuracyScores);
      });
  }

  reset() {
    this.setState({
      maxNumber: null
    });
    this.signaturePad.clear();
    const elements = document.querySelectorAll(".accuracy");
    elements.forEach(el => {
      el.parentNode.classList.remove("is-selected");
      el.innerText = "-";
    });
  }

  render() {
    let text = "数字を入力してください";
    if (this.state.maxNumber !== null) {
      if (this.state.maxScore > 0.999) {
        text = `この数字は確実に${this.state.maxNumber}です。`;
      } else if (this.state.maxScore > 0.9) {
        text = `この数字はほぼ間違いなく${this.state.maxNumber}です。`;
      } else if (this.state.maxScore > 0.5) {
        text = `この数字は多分${this.state.maxNumber}です。`;
      } else {
        text = `この数字は${this.state.maxNumber}かもしれないです。`;
      }
    }
    return (
      <div className="container">
        <h2>{text}</h2>
        <div className="canbas">
          <SignatureCanvas
            ref={this.onRef}
            minWidth={15}
            maxWidth={15}
            penColor="white"
            backgroundColor="black"
            canvasProps={{
              width: 420,
              height: 420,
              className: "sigCanvas"
            }}
            onEnd={this.predict}
          />
        </div>
        <div className="button">
          <Button variant="contained" onClick={this.reset}>
            reset
          </Button>
        </div>
        <AccuracyTable />
      </div>
    );
  }
}

export default App;
App.css
.container {
  margin-bottom: 120px;
  text-align: center;
}

.canbas {
  display: inline;
}

.button {
  display: block;
  margin-top: 20px;
  margin-bottom: 60px;
}

.table {
  display: inline;
  border-collapse: collapse;
  border-spacing: 0;
}

.table th, .table td {
  padding: 10px 0;
  width: 200px;
  text-align: center;
}

.table tr:nth-child(odd) {
  background-color: #eee;
}

.is-selected {
  color: red;
}

終わりに

PCでは、うまくいったのですが、スマホからアクセスするとうまくいきませんでした。今後の課題としては、スマホアプリ(Expo+ReactNative)にしてみたいなと思っています。あとは、機械学習を勉強してモデルの作成もいちからやってみたいです。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3