0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ブラウザで手書き数字を描いてAIが判定するWebアプリを作る【Flask + PyTorch】

Posted at

はじめに

ブラウザで数字を描くとAIがリアルタイムで判定してくれるWebアプリを作ってみました。
Flask初学者が躓いたポイントや解決方法も含めて詳しく解説します。

環境

  • Windows11
  • WSL2
  • Python 3.8+
  • Flask初学者

ゴールイメージ

  1. ブラウザでCanvas上に手書きで数字を描画
  2. 「判定」ボタンを押す
  3. AIが判定した結果が即座に表示される

全体構成

パート 内容 使う技術
① APIサーバ Flaskで /predict エンドポイントを作る Python, Flask
② 推論処理 PyTorchで数字を判定 Python, PyTorch
③ Webフロント HTML + JavaScript(Canvas)で描画UI HTML, JS
④ 連携確認 JSからFlaskに画像送信してAIが答える fetch API

実装手順

ステップ①:Flaskインストール

まずWSL2でFlaskをインストールします。
※FlaskはpythonでWebアプリやAPIサーバを作るためのフレームワークです。

pip install flask

確認:

python3 -m flask --version

出力結果

~$ python3 -m flask --version
Python 3.8.10
Flask 3.0.3
Werkzeug 3.0.6

ステップ②:最小のFlaskサーバを作る

app.py を新しく作成:
../mnist/mnist_model.pthは前回の記事で作成した学習モデルを指します

ちなみに、学習済モデルは重み(学習済パラメータ)のみが保存されているため、
モデル定義をして各層の構造を教えてあげる必要があります。

app.py
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageOps
import io
import base64

app = Flask(__name__)

# --- モデル定義 ---
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# 学習済みモデルを読み込み(重みパラメータのみ)
model.load_state_dict(torch.load("../mnist/mnist_model.pth", weights_only=True))
model.eval()  # 推論モードに切り替え(学習時の更新を無効化)

# 画像前処理の設定
transform = transforms.Compose([
    transforms.Grayscale(),      # グレースケール化(1チャンネル)
    transforms.Resize((28, 28)), # サイズを学習データと合わせる
    transforms.ToTensor(),       # Tensor形式に変換
])

# --- Flaskルーティング設定 ---
@app.route("/predict", methods=["POST"])
def predict():
    # Base64形式(文字列化された画像)を取得
    data = request.json["image"]

    # "data:image/png;base64,..." の後半のみを抽出してデコード
    image_data = base64.b64decode(data.split(",")[1])

    # io.BytesIOでバイナリデータを“ファイルのように”扱い、画像として開く
    img = Image.open(io.BytesIO(image_data))

    # グレースケール変換 + 白黒反転(MNISTと同じ形式に揃える)
    img = ImageOps.invert(img.convert("L"))

    # Tensor変換+バッチ次元を追加(形状を [1, 1, 28, 28] に)
    img_tensor = transform(img).unsqueeze(0)

    # 推論処理(学習せず結果だけ算出)
    with torch.no_grad():
        outputs = model(img_tensor)
        predicted = torch.argmax(outputs, 1).item()  # 最も確信度の高いクラスを取得

    # 結果をJSON形式で返す
    return jsonify({"prediction": int(predicted)})

# --- サーバ起動 ---
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)


ステップ③:HTMLで描画&送信するUIを用意

手書き入力とAI連携のためのHTMLを作成します。
app.pyと同じフォルダにindex.htmlを作成します。

index.html
<!DOCTYPE html>
<html lang="ja">
<head>
  <meta charset="UTF-8">
  <title>手書き数字認識</title>
  <style>
    canvas { border: 1px solid #333; background: white; }
    button { margin-top: 10px; }
  </style>
</head>
<body>
  <h2>手書き数字を描いてみよう</h2>
  <canvas id="canvas" width="200" height="200"></canvas><br>
  <button id="predictBtn">予測する</button>
  <p id="result"></p>

  <script>
    const canvas = document.getElementById("canvas");
    const ctx = canvas.getContext("2d");
    ctx.lineWidth = 12;
    ctx.lineCap = "round";

    let drawing = false;
    canvas.addEventListener("mousedown", () => drawing = true);
    canvas.addEventListener("mouseup", () => drawing = false);
    canvas.addEventListener("mousemove", draw);

    function draw(e) {
      if (!drawing) return;
      ctx.strokeStyle = "black";
      ctx.lineTo(e.offsetX, e.offsetY);
      ctx.stroke();
      ctx.beginPath();
      ctx.moveTo(e.offsetX, e.offsetY);
    }

    document.getElementById("predictBtn").onclick = async () => {
      const dataURL = canvas.toDataURL("image/png");
      const res = await fetch("/predict", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ image: dataURL })
      });
      const result = await res.json();
      document.getElementById("result").innerText = "予測された数字: " + result.prediction;
    };
  </script>
</body>
</html>

ステップ④:サーバを起動してアクセス!

python3 app.py

出力例:

 * Running on http://0.0.0.0:5000

出力結果

~/pytorch-playground/webapp$ python3 app.py
 * Serving Flask app 'app'
 * Debug mode: off
WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.
 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://<内部PCのIPアドレス>:5000
Press CTRL+C to quit

動作確認

以下のURLをブラウザで開きます。
http://localhost:5000/static/index.html

数字を描いて「予測する」ボタンを押すと、
AIが /predict に画像を送り、予測結果が下に出ました!

image.png

何描いても3になる問題

image.png

何度か試しましたが、予測数字は3にしかなりませんでした。なぜ

改善策①:学習エポックを増やす

まずは簡単にできる、学習エポックを増やしてみます。

for epoch in range(10):  # ← 10〜15に増やす

10に変更し、以下を実行。

python3 mnist.py

※前回記事参照:

実行結果

~/pytorch-playground/mnist$ python3 mnist.py
Epoch 1:loss0.0653
Epoch 2:loss0.0871
Epoch 3:loss0.0097
Epoch 4:loss0.1053
Epoch 5:loss0.1043
Epoch 6:loss0.0354
Epoch 7:loss0.0106
Epoch 8:loss0.0202
Epoch 9:loss0.0028
Epoch 10:loss0.0336
学習完了!

image.png

変わらず何を書いても推測結果は4でした。

改善策③:推論前の確認

学習がうまくいってるかをチェックするため、学習後に↓を追加しました。

correct = 0
total = 0
for images, labels in train_loader:
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print(f"Training Accuracy: {100 * correct / total:.2f}%")

これで「学習済みモデルの精度」が見える。
70〜80%超えたらまず合格、90%超えたらWebでも安定して動くようになるとのことです。

:~/pytorch-playground/mnist$ python3 mnist.py
Epoch 1:loss0.2892
Epoch 2:loss0.2797
Epoch 3:loss0.1063
Epoch 4:loss0.0934
Epoch 5:loss0.0698
Epoch 6:loss0.1539
Epoch 7:loss0.0228
Epoch 8:loss0.0441
Epoch 9:loss0.0425
Epoch 10:loss0.0022
学習完了!
Training Accuracy: 99.52%

99%超えとのこと。え???なぜ動かない???

デバックを追加

原因を探るため、バッチ次元を追加後にデバッグ処理を追加しました。

app.py
# デバッグ: 画像を保存して確認
  img.save("debug_image.png")
  print(f"画像サイズ: {img.size}")
  print(f"画像モード: {img.mode}")
  print(f"テンソル形状: {img_tensor.shape}")
  print(f"テンソル値の範囲: min={img_tensor.min():.4f}, max={img_tensor.max():.4f}")
# 推論モードに切り替え
  with torch.no_grad():
    outputs = model(img_tensor)
    print(f"モデル出力: {outputs}")
    print(f"各クラスのスコア: {outputs[0].tolist()}")
    # AIが最も自信のある数字を取得
    predicted = torch.argmax(outputs, 1).item()
    print(f"予測: {predicted}")

出力結果

画像サイズ: (200, 200)
画像モード: L
テンソル形状: torch.Size([1, 1, 28, 28])
テンソル値の範囲: min=1.0000, max=1.0000
モデル出力: tensor([[ -25.7221,  -28.7170,   17.4366,    2.0993, -103.4495,   15.5843,
          -14.0252,   -2.0827,   -9.3242,  -56.5773]])
各クラスのスコア: [-25.722137451171875, -28.71704864501953, 17.436573028564453, 2.0993316173553467, -103.44945526123047, 15.584332466125488, -14.025151252746582, -2.082690954208374, -9.32417106628418, -56.57730484008789]
予測: 2

画像のすべてのピクセル値が1.0(完全に白)になっていました!!!

これは、キャンバスから取得した画像がRGBA形式で、アルファチャンネルの処理が正しくないことが原因。

以下を追記

app.py
# RGBAの場合、白背景でRGBに変換
  if img.mode == 'RGBA':
    rgb_img = Image.new('RGB', img.size, (255, 255, 255))
    rgb_img.paste(img, mask=img.split()[3])  # アルファチャンネルをマスクとして使用
    img = rgb_img

サーバー再起動後のデバック

画像サイズ: (200, 200)
画像モード: L
テンソル形状: torch.Size([1, 1, 28, 28])
テンソル値の範囲: min=0.0000, max=0.9961
モデル出力: tensor([[-14.2575, -12.3142,  -2.3615,   2.1547, -12.1644,  -2.3196, -11.8768,
          -7.0769,   7.7508,  -3.2284]])
各クラスのスコア: [-14.257462501525879, -12.314240455627441, -2.36152720451355, 2.1546876430511475, -12.164447784423828, -2.319572687149048, -11.876819610595703, -7.076918601989746, 7.750839710235596, -3.2284023761749268]
予測: 8

image.png

推測できました!

最終的な全体のコード

最終的なapp.py
app.py
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageOps
import io
import base64

app = Flask(__name__)

# モデル定義
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
# 学習済モデルを読み込み
model.load_state_dict(torch.load("../mnist/mnist_model.pth", weights_only=True))
model.eval()

transform = transforms.Compose([
  transforms.Grayscale(),
  transforms.Resize((28, 28)),
  transforms.ToTensor(),
])

# Flaskのルーティング
@app.route("/predict", methods=["POST"])
def predict():
  data = request.json["image"] #Base64形式(文字列に変換された画像)で受け取る
  # 純粋な画像部分だけを抜き出す
  image_data = base64.b64decode(data.split(",")[1])
  img = Image.open(io.BytesIO(image_data)) # io.ByteIOはバイナリデータをファイルっぽく扱えるようにしている

  # RGBAの場合、白背景でRGBに変換
  if img.mode == 'RGBA':
    rgb_img = Image.new('RGB', img.size, (255, 255, 255))
    rgb_img.paste(img, mask=img.split()[3])  # アルファチャンネルをマスクとして使用
    img = rgb_img

  # グレースケールに変換後白黒反転
  img = ImageOps.invert(img.convert("L"))
  # バッチ次元を追加
  img_tensor = transform(img).unsqueeze(0)

  # デバッグ: 画像を保存して確認
  img.save("debug_image.png")
  print(f"画像サイズ: {img.size}")
  print(f"画像モード: {img.mode}")
  print(f"テンソル形状: {img_tensor.shape}")
  print(f"テンソル値の範囲: min={img_tensor.min():.4f}, max={img_tensor.max():.4f}")

  # 推論モードに切り替え
  with torch.no_grad():
    outputs = model(img_tensor)
    print(f"モデル出力: {outputs}")
    print(f"各クラスのスコア: {outputs[0].tolist()}")
    # AIが最も自信のある数字を取得
    predicted = torch.argmax(outputs, 1).item()
    print(f"予測: {predicted}")

  return jsonify({"prediction": int(predicted)})
# サーバー起動
if __name__=="__main__":
  app.run(host="0.0.0.0", port=5000)

おわりに

推測ができない原因もAI(Claude Code)に解析してもらいました。
AIすごい。
AIを用いてAIを学んでいけたら楽しそうだと思いました。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?