Help us understand the problem. What is going on with this article?

CNNで手書き文字を活字に変換

趣旨

手書き文字を活字に変換してみます。
分類によって文字を特定し活字を出力することもできますが、ここではあくまで画像から画像への変換を行います。
(と言いながら、実は今回の場合、内部的には分類とあまり変わらないとも考えられます。後述)

具体的には、通常のCNNの出力を画像にして、画像から画像への変換を行います。
出力が画像である点以外は、よくある画像分類を行うCNNと変わりません。
そのようなシンプルな教師あり学習で、うまくいくかどうか試してみます。

※ 簡易的に色々と試してみるのが目的であり、最適な手法を示すものではありません。

※ import

from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import cv2
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm_notebook as tqdm

from keras.layers import Reshape, Conv2D
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model, load_model
from keras.applications.resnet50 import ResNet50
import keras.optimizers as optimizers
import keras.losses as losses
import keras.callbacks

# EMNIST
import emnist

データセット - EMNIST

EMNISTというデータセットを使います。
名前からうかがえるように、MNISTの拡張です。
数字だけでなくアルファベットも含みます。

今回はpipでemnistパッケージをインストールします。

!pip install emnist
import emnist

以下のような異なる分割方法を選択することができます。
byclass, bymerge, balanced, letters, digits, mnist

分割方法 データ数 クラス数 内容
byclass 814,255 62 [0-9]、[a-z]、および[A-Z]の62クラスに分類された文字を含む。 各クラスのデータ数はバラバラ。
bymerge 814,255 47 byclassのうち、特定の文字の大文字小文字が同じクラスにマージされている。 マージ対象の文字は[C, I, J, K, L, M, O, P, S, U, V, W, X, Y, Z] 各クラスのデータ数はバラバラ。
balanced 131,600 47 bymergeの各クラスのデータ数を揃えている。
letters 145,600 37 アルファベットのみ。 全てのクラスの大文字小文字をマージしている。 各クラスの件数は揃えている。
digits 280,000 10 数字のみ。 各クラスの件数は揃えている。
mnist 70,000 10 MNIST相当。

このうち、数字とアルファベットを含み、大文字と小文字のクラスが区別された計62クラスの byclassでロードします。

# EMNISTをロード
train_X, train_Y = emnist.extract_training_samples('byclass')
test_X, test_Y = emnist.extract_test_samples('byclass')

# 各クラスを文字に直したもの
chars = [chr(i) for i in range(48, 48+10)] + [chr(i) for i in range(65, 65+26)] + [chr(i) for i in range(97, 97+26)]
num_classes = len(chars)
print(chars)
#['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

さて、画像を出力とする教師あり学習なので、正解値としての画像が必要です。
EMNISTのラベルから文字を取得し、サンセリフ体のフォントで画像へ描画し、ndarrayに変換します。
出力画像のサイズ、形式も入力画像に合わせています。

def get_char_size(char, font):
  testImg = Image.new('RGB', (1, 1))
  testDraw = ImageDraw.Draw(testImg)
  return testDraw.textsize(char, font)

def get_char_size_max(font):
  max_width, max_height = 1, 1
  for char in chars:
    width, height = get_char_size(char, font)
    max_width, max_height = max(width, max_width), max(height, max_height)
  return max_width, max_height

def string_to_img_array(text):
  """
  文字列を画像に描画しnumpy配列に変換する
  """

  # font_manager.findSystemFonts(fontpaths=None, fontext='ttf')
  font_path = '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf'
  fontsize = 28
  font = ImageFont.truetype(font_path, fontsize)

  # width, height = get_char_size_max(font)
  # print(width, height)
  width, height = 32, 32

  colorText = "black"
  colorBackground = "white"
  colorOutline = "white"

  img = Image.new('L', (width, height), colorBackground)
  d = ImageDraw.Draw(img)
  d.text((1, 1), text, fill=colorText, font=font)
  d.rectangle((0, 0, width, height), outline=colorOutline)
  img_array = np.array(img)

  return img_array

のちほど改めて触れますが、今回モデル自体は簡便に済ますため、Kerasに組み込まれている ResNet50 を利用します。ただ、ResNet50は、最低でもサイズが32×32、チャネル数が3の画像データしか受け取ってくれないため、リサイズやチャネルの追加を行っています。

def convet_x_for_resnet50(X):
  X = X.reshape(X.shape + (1,))
  X = np.array([cv2.resize(x, dsize=(32, 32), interpolation=cv2.INTER_CUBIC) for x in X])
  return np.stack((X,) * 3, axis=-1)

def convet_y_for_resnet50(Y):
  # 正解画像の作成もここで実行
  Y = np.array([string_to_img_array(chars[y]) for y in tqdm(Y)] )
  return np.stack((Y,) * 3, axis=-1)

# Keras組み込みのResNet50に入力できるよう便宜的に変換(+正規化)
train_X = convet_x_for_resnet50(train_X) / 255
train_Y = convet_y_for_resnet50(train_Y) / 255
test_X = convet_x_for_resnet50(test_X) / 255
test_Y = convet_y_for_resnet50(test_Y) / 255

もちろん、入力データに合わせてモデルを自分で構築する場合は必要ありません。

入力画像と正解画像を並べてみます。

def reshape_for_sample_show(data):
  data = data.reshape(-1,8,32,32,3)
  data = data.transpose(0,2,1,3,4)
  return data.reshape(8*32,8*32,3)

sample_X = reshape_for_sample_show(train_X[:64])
sample_Y = reshape_for_sample_show(train_Y[:64])
plt.figure(figsize=(16,8))
plt.subplot(1, 2, 1)
plt.title('Input')
plt.imshow(sample_X)
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title('Correct')
plt.imshow(sample_Y)
plt.axis('off')

emnist.png

モデル

先に述べたように、Kerasの組み込みのResNet50を流用します。
もっとも、ImageNetの分類とは入力サイズもデータもタスクもかなり異なるので、学習済みの重みは使用せず、イチから学習します。
ResNet50の引数のinclude_topをFalseにすることで、出力層側のGAPと全結合層を外します。
出力層として畳み込み層(Conv2D)を追加することで、入力画像と同じサイズの画像を出力するようにします。
(ちなみに実は出力層よりかなり手前で特徴マップの縦横のサイズが1になるという意味で、二次元構造は一度なくなっています)

#適宜調整
batch_size = 180
epochs = 120#36#96
lr = 0.0002

※ 実際は訓練の中断や、訓練途中の重みの読込などを行っており、上記のハイパーパラメータによる1回の実行で終えているわけではありません。

base_model = ResNet50(weights=None, include_top=False, input_shape=train_X.shape[1:])
model = base_model.output
model = Reshape((32, 32, 2))(model)

# 出力層
model = Conv2D(3, (32, 32), padding='same', activation='linear')(model)

model = Model(inputs=base_model.input, outputs=model)

ピクセルの明るさについての回帰なので、損失関数は平均二乗誤差、重みは訓練し直すのでオプティマイザーはAdamとします。

model.compile(loss=losses.mean_squared_error,
                optimizer=optimizers.Adam(lr=lr),
                metrics=['mae', 'mse'])

訓練・評価

# 必要ならチェックポイント設定
checkpoint = keras.callbacks.ModelCheckpoint(filepath = 'checkpoint.h5', monitor='val_mean_squared_error', verbose=1, save_best_only=True, mode='auto')
cbs = [checkpoint]

# 訓練
history = model.fit(train_X, train_Y, batch_size=batch_size, epochs=epochs,
                    verbose=1, validation_split=0.1, callbacks=cbs)

model.save('saved_model.h5')

# 評価
score = model.evaluate(test_X, test_Y, verbose=0)
print('Test loss:', score[0])
print('Test mae:', score[1])


hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch

plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Abs Error')
plt.plot(hist['epoch'], hist['mean_absolute_error'],
          label='Train Error')
plt.plot(hist['epoch'], hist['val_mean_absolute_error'],
          label = 'Val Error')
plt.legend()

plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Square Error')
plt.plot(hist['epoch'], hist['mean_squared_error'],
          label='Train Error')
plt.plot(hist['epoch'], hist['val_mean_squared_error'],
          label = 'Val Error')
plt.legend()
plt.show()

結果

実際にテストデータの一部を変換してみます。

predicted = model.predict(test_X[:64])

sample_input = reshape_for_sample_show(test_X[:64])
sample_correct = reshape_for_sample_show(test_Y[:64])
sample_predicted = reshape_for_sample_show(predicted)

plt.figure(figsize=(18,6))

# 入力
plt.subplot(1, 3, 1)
plt.title('Input')
plt.imshow(sample_input)
plt.axis('off')

# 出力
plt.subplot(1, 3, 2)
plt.title('Predicated')
plt.imshow(sample_predicted)
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title('Correct')
plt.imshow(sample_correct)
plt.axis('off')

左が入力画像、右が出力画像です。

prediction正解以外.png

それっぽい感じになりました。
人間が見ても紛らわしいものは、正解と比べると間違えてたりします。
迷ったところは、うっすらと出るのがなんだかかわいいです(笑)

未知の文字(カタカナ)を入力すると…

ここで学習データにはない手書きの「ア」を入力すると…

katakana = np.expand_dims(cv2.resize(np.array(Image.open('ア.png')), dsize=(32, 32), interpolation=cv2.INTER_CUBIC), axis=0) / 255
predicted = model.predict(katakana)
plt.figure(figsize=(6,6))
plt.imshow(predicted[0])
plt.axis('off')

アを並べて.png

はい、既知のアルファベットに変換しようとしています。
汎用的なスタイル変換を学習できているわけではないということです。

今回は答えの画像の種類が62パターン固定なので、モデルの気持ちになって考えると、62パターンのうちどのパターンなのかわかれば、あとは固定の変換をすれば目的の画像が得られます。
なので、広い意味ではラベルの形が違うだけの分類であるという見方もできるわけです。

まとめ

いらないことも取り入れながら、出力を画像にして、楽しく手書き文字から活字への変換ができました。

では、たとえば日本語で使われる多くの文字を対象としつつ、未知の文字を入力しても活字っぽく変換するにはどうすべきか。
GAN周りの技術も取り入れつつ、次回試してみたいと思います。

ソースコード
https://gist.github.com/shindooo/d9c2ba9cd35720400ec1bdfd531fcabd

shindooo
投稿内容は私個人の見解であり、所属する組織の公式見解ではありません。
https://www.kronos.jp/
kronos-jp
AI開発・WEB開発・システム開発・Android開発・iOS開発・IT研修・トレーニング・新入社員研修などを行う企業です。
https://www.kronos.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした