LoginSignup
2
2

More than 3 years have passed since last update.

かなの手書き文字認識を作ってみた Part 1/3 まずは MNIST から

Last updated at Posted at 2020-11-07

概要

GUI でかなを入力し、予め機械学習により訓練して作成したモデルにより、文字を検出させようとしました。

まず MNIST で CNN の感触と精度を確認し、次に実際にかなのデータを与えて学習させ、最後に GUI と連携させます。

次回 (2/3): https://qiita.com/tfull_tf/items/968bdb8f24f80d57617e
次次回 (3/3): https://qiita.com/tfull_tf/items/d9fe3ab6c1e47d1b2e1e

全体のコードは次の場所にあります。
https://github.com/tfull/character_recognition

MNIST とモデル構築

独自のモデルを構築し、よく使われる手書き数字のデータセット MNIST で、 train, test を行い、どれくらいの精度が出るのかを試します。

MNIST は 28x28 のグレースケールデータのため、 (channel, width, height) = (1, 28, 28) として入力します。数字は 0 ~ 9 のため、10通りの分類先があり、10個の確率を出力するようにします。

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout2d(0.3)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(12 * 12 * 32, 128)
        self.relu3 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.linear2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu3(x)
        x = self.dropout2(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x

2つの畳み込み層とその後のプーリング層を経由し、1次元に変換して全結合層を2つ通す。活性化関数は ReLU で、途中で過学習防止のドロップアウト層を入れる、というのがモデルの概形です。

データ取得

import torchvision

download_flag = not os.path.exists(data_directory + "/mnist")

mnist_train = torchvision.datasets.MNIST(
    data_directory + "/mnist",
    train = True,
    download = download_flag,
    transform = torchvision.transforms.ToTensor()
)

mnist_test = torchvision.datasets.MNIST(
    data_directory + "/mnist",
    train = False,
    download = download_flag,
    transform = torchvision.transforms.ToTensor()
)

MNIST のデータをローカルに保存して、それを使うようにします。
data_directory は定義しておいて、ディレクトリがなければダウンロードをするようにします。そうすることで、最初の1回だけダウンロードをするようにしました。

学習準備

import torch
import torch.optim as optim

train_loader = torch.utils.data.DataLoader(mnist_train,  batch_size = 100,  shuffle = True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size = 1000, shuffle = False)

model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001)

DataLoader を使って、順番にデータを取り出せるようにします。

モデル、誤差関数、最適化アルゴリズムを設定します。交差エントロピー誤差、 Adam を採用しました。

訓練

n_epoch = 2

model.train()

for i_epoch in range(n_epoch):
    for i_batch, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        print("epoch: {}, train: {}, loss: {}".format(i_epoch + 1, i_batch + 1, loss.item()))

画像データ (inputs) をモデル (model) に与え、出力 (output) と正解データ (labels) を比較して誤差を求め、逆伝搬させるという、学習を行う際の一連の操作をループ内で行っています。

各データを1回与えるだけでは学習に足りないと思うので、エポック数 (n_epoch) を2とし、各データを n_epoch 回与えて学習させています。エポック数は、私の経験ですが、2~3くらいがちょうどよいのではと思っています。データの数にもよると思いますが。

評価

correct_count = 0
record_count = 0

model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, prediction = torch.max(outputs.data, 1)
        judge = prediction == labels
        correct_count += int(judge.sum())
        record_count += len(judge)

print("Accuracy: {:.2f}%".format(correct_count / record_count * 100))

モデルに画像の数値データ (inputs) を入力し、出てきた10つの確率のうち最も高いものを選択結果 (prediction) としています。それが正解データ (labels) と一致しているかを比較して True/False を返し、全体の数 (record_count) に対する True の個数 (correct_count) を計算して正答率としています。

結果と考察

結果は、複数回を平均して、約 97% になりました。

正答率の値としては高いとは思いますが、100回に3回は失敗しています。人間がこれを許容できるかはまた別の問題になってくると思います。しかし、 MNIST の画像データの中には、人間が見ても判別しづらい汚い文字があったりするので、そういう意味では 3% の間違いは仕方ないかもしれません。

MNIST は 0 ~ 9 の10択ですが、かなについてはひらがなとカタカナで100以上ありますから、分類も難しくなり、正答率がもっと落ちることは覚悟しないといけないでしょう。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2