はじめに
ブラウザで数字を描くとAIがリアルタイムで判定してくれるWebアプリを作ってみました。
Flask初学者が躓いたポイントや解決方法も含めて詳しく解説します。
環境
- Windows11
- WSL2
- Python 3.8+
- Flask初学者
ゴールイメージ
- ブラウザでCanvas上に手書きで数字を描画
- 「判定」ボタンを押す
- AIが判定した結果が即座に表示される
全体構成
| パート | 内容 | 使う技術 |
|---|---|---|
| ① APIサーバ | Flaskで /predict エンドポイントを作る |
Python, Flask |
| ② 推論処理 | PyTorchで数字を判定 | Python, PyTorch |
| ③ Webフロント | HTML + JavaScript(Canvas)で描画UI | HTML, JS |
| ④ 連携確認 | JSからFlaskに画像送信してAIが答える | fetch API |
実装手順
ステップ①:Flaskインストール
まずWSL2でFlaskをインストールします。
※FlaskはpythonでWebアプリやAPIサーバを作るためのフレームワークです。
pip install flask
確認:
python3 -m flask --version
出力結果
~$ python3 -m flask --version
Python 3.8.10
Flask 3.0.3
Werkzeug 3.0.6
ステップ②:最小のFlaskサーバを作る
app.py を新しく作成:
※../mnist/mnist_model.pthは前回の記事で作成した学習モデルを指します
ちなみに、学習済モデルは重み(学習済パラメータ)のみが保存されているため、
モデル定義をして各層の構造を教えてあげる必要があります。
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageOps
import io
import base64
app = Flask(__name__)
# --- モデル定義 ---
model = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
# 学習済みモデルを読み込み(重みパラメータのみ)
model.load_state_dict(torch.load("../mnist/mnist_model.pth", weights_only=True))
model.eval() # 推論モードに切り替え(学習時の更新を無効化)
# 画像前処理の設定
transform = transforms.Compose([
transforms.Grayscale(), # グレースケール化(1チャンネル)
transforms.Resize((28, 28)), # サイズを学習データと合わせる
transforms.ToTensor(), # Tensor形式に変換
])
# --- Flaskルーティング設定 ---
@app.route("/predict", methods=["POST"])
def predict():
# Base64形式(文字列化された画像)を取得
data = request.json["image"]
# "data:image/png;base64,..." の後半のみを抽出してデコード
image_data = base64.b64decode(data.split(",")[1])
# io.BytesIOでバイナリデータを“ファイルのように”扱い、画像として開く
img = Image.open(io.BytesIO(image_data))
# グレースケール変換 + 白黒反転(MNISTと同じ形式に揃える)
img = ImageOps.invert(img.convert("L"))
# Tensor変換+バッチ次元を追加(形状を [1, 1, 28, 28] に)
img_tensor = transform(img).unsqueeze(0)
# 推論処理(学習せず結果だけ算出)
with torch.no_grad():
outputs = model(img_tensor)
predicted = torch.argmax(outputs, 1).item() # 最も確信度の高いクラスを取得
# 結果をJSON形式で返す
return jsonify({"prediction": int(predicted)})
# --- サーバ起動 ---
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5000)
ステップ③:HTMLで描画&送信するUIを用意
手書き入力とAI連携のためのHTMLを作成します。
app.pyと同じフォルダにindex.htmlを作成します。
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<title>手書き数字認識</title>
<style>
canvas { border: 1px solid #333; background: white; }
button { margin-top: 10px; }
</style>
</head>
<body>
<h2>手書き数字を描いてみよう</h2>
<canvas id="canvas" width="200" height="200"></canvas><br>
<button id="predictBtn">予測する</button>
<p id="result"></p>
<script>
const canvas = document.getElementById("canvas");
const ctx = canvas.getContext("2d");
ctx.lineWidth = 12;
ctx.lineCap = "round";
let drawing = false;
canvas.addEventListener("mousedown", () => drawing = true);
canvas.addEventListener("mouseup", () => drawing = false);
canvas.addEventListener("mousemove", draw);
function draw(e) {
if (!drawing) return;
ctx.strokeStyle = "black";
ctx.lineTo(e.offsetX, e.offsetY);
ctx.stroke();
ctx.beginPath();
ctx.moveTo(e.offsetX, e.offsetY);
}
document.getElementById("predictBtn").onclick = async () => {
const dataURL = canvas.toDataURL("image/png");
const res = await fetch("/predict", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ image: dataURL })
});
const result = await res.json();
document.getElementById("result").innerText = "予測された数字: " + result.prediction;
};
</script>
</body>
</html>
ステップ④:サーバを起動してアクセス!
python3 app.py
出力例:
* Running on http://0.0.0.0:5000
出力結果
~/pytorch-playground/webapp$ python3 app.py
* Serving Flask app 'app'
* Debug mode: off
WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.
* Running on all addresses (0.0.0.0)
* Running on http://127.0.0.1:5000
* Running on http://<内部PCのIPアドレス>:5000
Press CTRL+C to quit
動作確認
以下のURLをブラウザで開きます。
http://localhost:5000/static/index.html
数字を描いて「予測する」ボタンを押すと、
AIが /predict に画像を送り、予測結果が下に出ました!
何描いても3になる問題
何度か試しましたが、予測数字は3にしかなりませんでした。なぜ
改善策①:学習エポックを増やす
まずは簡単にできる、学習エポックを増やしてみます。
for epoch in range(10): # ← 10〜15に増やす
10に変更し、以下を実行。
python3 mnist.py
※前回記事参照:
実行結果
~/pytorch-playground/mnist$ python3 mnist.py
Epoch 1:loss0.0653
Epoch 2:loss0.0871
Epoch 3:loss0.0097
Epoch 4:loss0.1053
Epoch 5:loss0.1043
Epoch 6:loss0.0354
Epoch 7:loss0.0106
Epoch 8:loss0.0202
Epoch 9:loss0.0028
Epoch 10:loss0.0336
学習完了!
変わらず何を書いても推測結果は4でした。
改善策③:推論前の確認
学習がうまくいってるかをチェックするため、学習後に↓を追加しました。
correct = 0
total = 0
for images, labels in train_loader:
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Training Accuracy: {100 * correct / total:.2f}%")
これで「学習済みモデルの精度」が見える。
70〜80%超えたらまず合格、90%超えたらWebでも安定して動くようになるとのことです。
:~/pytorch-playground/mnist$ python3 mnist.py
Epoch 1:loss0.2892
Epoch 2:loss0.2797
Epoch 3:loss0.1063
Epoch 4:loss0.0934
Epoch 5:loss0.0698
Epoch 6:loss0.1539
Epoch 7:loss0.0228
Epoch 8:loss0.0441
Epoch 9:loss0.0425
Epoch 10:loss0.0022
学習完了!
Training Accuracy: 99.52%
99%超えとのこと。え???なぜ動かない???
デバックを追加
原因を探るため、バッチ次元を追加後にデバッグ処理を追加しました。
# デバッグ: 画像を保存して確認
img.save("debug_image.png")
print(f"画像サイズ: {img.size}")
print(f"画像モード: {img.mode}")
print(f"テンソル形状: {img_tensor.shape}")
print(f"テンソル値の範囲: min={img_tensor.min():.4f}, max={img_tensor.max():.4f}")
# 推論モードに切り替え
with torch.no_grad():
outputs = model(img_tensor)
print(f"モデル出力: {outputs}")
print(f"各クラスのスコア: {outputs[0].tolist()}")
# AIが最も自信のある数字を取得
predicted = torch.argmax(outputs, 1).item()
print(f"予測: {predicted}")
出力結果
画像サイズ: (200, 200)
画像モード: L
テンソル形状: torch.Size([1, 1, 28, 28])
テンソル値の範囲: min=1.0000, max=1.0000
モデル出力: tensor([[ -25.7221, -28.7170, 17.4366, 2.0993, -103.4495, 15.5843,
-14.0252, -2.0827, -9.3242, -56.5773]])
各クラスのスコア: [-25.722137451171875, -28.71704864501953, 17.436573028564453, 2.0993316173553467, -103.44945526123047, 15.584332466125488, -14.025151252746582, -2.082690954208374, -9.32417106628418, -56.57730484008789]
予測: 2
画像のすべてのピクセル値が1.0(完全に白)になっていました!!!
これは、キャンバスから取得した画像がRGBA形式で、アルファチャンネルの処理が正しくないことが原因。
以下を追記
# RGBAの場合、白背景でRGBに変換
if img.mode == 'RGBA':
rgb_img = Image.new('RGB', img.size, (255, 255, 255))
rgb_img.paste(img, mask=img.split()[3]) # アルファチャンネルをマスクとして使用
img = rgb_img
サーバー再起動後のデバック
画像サイズ: (200, 200)
画像モード: L
テンソル形状: torch.Size([1, 1, 28, 28])
テンソル値の範囲: min=0.0000, max=0.9961
モデル出力: tensor([[-14.2575, -12.3142, -2.3615, 2.1547, -12.1644, -2.3196, -11.8768,
-7.0769, 7.7508, -3.2284]])
各クラスのスコア: [-14.257462501525879, -12.314240455627441, -2.36152720451355, 2.1546876430511475, -12.164447784423828, -2.319572687149048, -11.876819610595703, -7.076918601989746, 7.750839710235596, -3.2284023761749268]
予測: 8
推測できました!
最終的な全体のコード
最終的なapp.py
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageOps
import io
import base64
app = Flask(__name__)
# モデル定義
model = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
# 学習済モデルを読み込み
model.load_state_dict(torch.load("../mnist/mnist_model.pth", weights_only=True))
model.eval()
transform = transforms.Compose([
transforms.Grayscale(),
transforms.Resize((28, 28)),
transforms.ToTensor(),
])
# Flaskのルーティング
@app.route("/predict", methods=["POST"])
def predict():
data = request.json["image"] #Base64形式(文字列に変換された画像)で受け取る
# 純粋な画像部分だけを抜き出す
image_data = base64.b64decode(data.split(",")[1])
img = Image.open(io.BytesIO(image_data)) # io.ByteIOはバイナリデータをファイルっぽく扱えるようにしている
# RGBAの場合、白背景でRGBに変換
if img.mode == 'RGBA':
rgb_img = Image.new('RGB', img.size, (255, 255, 255))
rgb_img.paste(img, mask=img.split()[3]) # アルファチャンネルをマスクとして使用
img = rgb_img
# グレースケールに変換後白黒反転
img = ImageOps.invert(img.convert("L"))
# バッチ次元を追加
img_tensor = transform(img).unsqueeze(0)
# デバッグ: 画像を保存して確認
img.save("debug_image.png")
print(f"画像サイズ: {img.size}")
print(f"画像モード: {img.mode}")
print(f"テンソル形状: {img_tensor.shape}")
print(f"テンソル値の範囲: min={img_tensor.min():.4f}, max={img_tensor.max():.4f}")
# 推論モードに切り替え
with torch.no_grad():
outputs = model(img_tensor)
print(f"モデル出力: {outputs}")
print(f"各クラスのスコア: {outputs[0].tolist()}")
# AIが最も自信のある数字を取得
predicted = torch.argmax(outputs, 1).item()
print(f"予測: {predicted}")
return jsonify({"prediction": int(predicted)})
# サーバー起動
if __name__=="__main__":
app.run(host="0.0.0.0", port=5000)
おわりに
推測ができない原因もAI(Claude Code)に解析してもらいました。
AIすごい。
AIを用いてAIを学んでいけたら楽しそうだと思いました。



