LoginSignup
90
92

More than 3 years have passed since last update.

日本語OCRを作ったので解説してみる

Posted at

日本語OCRを作ったので詳しく解説してみる

GitHub↓で公開中。
https://github.com/tanreinama/OCR_Japanease

日本語OCRとは

文字通り日本語のOCRです。OCRとは、画像から文字を認識するプログラムです。
catch.png

前回の記事
https://qiita.com/tanreinama/items/e171449e66d5221afe7e

解説

使用するニューラルネットワーク

このOCRプログラムのメインは、基本的にはディープラーニングによって作成されたニューラルネットワークの実行です。

OCRに必要となるのは、文章領域・文字の検出用と、文字のクラス分類用の二つのニューラルネットワークです。

Center Line Detection

まず、OCRでは、画像中にある文字を全て取り出せば、それで良いわけではありません。

文字は繋がって文章となり、ひとまとまりの文章として認識することで初めて意味のある結果になります。

このOCRプログラムでは、文章領域の抽出と、文字の抽出を、一つのニューラルネットワークで行っているのが特徴です。

ここで何をやっているかを知るには、CenterNetについての知識があると望ましいでしょう。CenterNetでは、物体検出に際して、BoundingBoxではなく中心点(からのガウス分布)をターゲットに学習させます。

この、中心点の検出は、文字単体についてはとてもよく動くのですが、文章領域の抽出には不向きです。

しかし、文章は、文字が縦または横に繋がっているため、文字の配置には制約が存在しています。

そこで導入したのは、文章を、文字の中心線を繋げた線として認識するソリューションです。

CenterLineDetection

私はこの手法を「Center Line Detection」と呼んでいます。CenterNetの発想と同じく、中心線とそこからのガウス分布を学習させます。一度のニューラルネットワークの実行で、文字一つ一つの位置であるCenter PointとCenter Lineが同時に得られます。

プログラムコード

プログラムでは、misc/detection.pyにある、_detectNx関数がCenter Line Detectionのニューラルネットワークの実行を担当します。

def _detect1x(self, detector_model, gray_img_scaled):
    im = np.zeros((1,1,512,512))
    im[0,0,0:gray_img_scaled.shape[0],0:gray_img_scaled.shape[1]] = gray_img_scaled
    x = np.clip(im / 255, 0.0, 1.0).astype(np.float32)
    x = torch.tensor(x)
    dp = torch.nn.DataParallel(detector_model)
    if self.use_cuda:
        x = x.cuda()
        dp = dp.cuda()
    dp.eval()
    y = dp(x)

    hm_wd = ((y[0]['hm_wd'] + y[1]['hm_wd']) / 2).detach().cpu().numpy().reshape(128,128)
    hm_sent = ((y[0]['hm_sent'] + y[1]['hm_sent']) / 2).detach().cpu().numpy().reshape(128,128)
    hm_pos = ((y[0]['of_size'] + y[1]['of_size']) / 2).detach().cpu().numpy().reshape(2,64,64)
    del x, y
    if self.use_cuda:
        torch.cuda.empty_cache()
    return hm_wd, hm_sent, hm_pos

Pytorchのモデルからは、hm_wdhm_sentof_sizeを含むディクショナリが返されます。
これは、ニューラルネットワークに前回作ったHourgrassNetの亜種を使っているので、Hourglassの複数の場所から出力を取りだしているからです。Hourglassについては前回の解説を参照してください。

文章領域の抽出

Center Line Detectionが出来れば、次はそれを文章領域にします。

下は左から、入力画像と、適切なDPIになるよう解像度を調整したもの(ニューラルネットワークへの入力)、Center Pointの認識結果、Center Lineの認識結果です。

Center Line Detectionによって文字の繋がりが見えていることが解ります。

CenterLineDetectionの結果

割と複雑な画像でも、DPIさえ適切に設定してやれば、きちんと識別出来ます。
例として、私の書いた本の表紙を切り出して、入力してみます。readmeにあるように、モデルの種類を「font」に指定すると、そこそこ上手く動きます。

$ python3 ocr_japanease.py --model font testshot2.png

その際のCenter Line Detectionの結果が以下で、縦書きの部分もきちんとひと繋がりの文として認識していることが解ります。

CenterLineDetectionの結果

ちなみに、このCenter Line Detectionに使うニューラルネットワークは、512x512ピクセルの高解像度で入力します。なので、GPUメモリを大食いします。6GBのメモリでバッチサイズ4が限界という感じ。ニューラルネットワークのパラメーター数は少ないのですが、内部を通るデータの分が大きいです。
上記の本の表紙は、解像度が大きいので4分割してニューラルネットワークを実行しています(そのため1024x1024ピクセルでの認識)。
OCRプログラムでは、最大16分割まで分割して認識するので、2048x2048ピクセルまで対応しています。Center Line Detectionでは高解像度の画像を使いますが、後で文字をクラス分類するときには、BoundingBoxを切り出した低解像度の画像を使うので問題ありません。
misc/detection.pyにある、_detect4x関数と_detect16x関数では、バッチサイズ=4に領域を分割してニューラルネットワーク実行し、結果をマージしています。

文章領域のクラスタリング

そして、認識したCenter Lineを、OPTICSアルゴリズムによってクラスタリングします。

下は、クラスタリングの結果(一番左)と、その領域を内接矩形で囲ったものになります。
Center Line Detectionの結果を閾値(0.007)で1分けて、閾値以上の部分のXY座標の配列をOPTICSアルゴリズムに食わせればOKです。

これで、画像中のどこが文章領域であるかが認識出来ました。

クラスタリング

先ほどの本の表紙をクラス毎にマスクすると、こんな感じになります。

OPTICS結果

スクリーンショット 2020-04-11 9.26.54.png

解像度的に小さな文字は認識出来ていませんが、DPIの設定と文字サイズの兼ね合いで、1画像に含まれる認識したい文字の大きさがあまりにばらけていると、全てを認識する事は難しくなってしまいます。

プログラムコード

対応するプログラムのコードは、misc/detection.pyのこのあたりです。
_get_class関数でクラスタリングの結果を得て、_get_map関数で内接矩形を得ます。そして_filt_map関数で重複する部分を消して大きな領域にマージすると、文章領域毎のマスク画像が出来上がります。

def _get_class(self, im):
    minmax = (im.min(), im.max())
    if minmax[1]-minmax[0] == 0:
        return np.array()
    im = (im-minmax[0]) / (minmax[1]-minmax[0])
    clf = OPTICS(metric='euclidean', min_cluster_size=75)
    a = []
    for x in range(im.shape[0]):
        for y in range(im.shape[1]):
            if im[x][y] > self.sentence_threshold:
                a.append([x,y])
    b = clf.fit_predict(a)
    c = np.zeros(im.shape)
    for i in range(len(b)):
        c[a[i][0],a[i][1]] = b[i]+1
    return c

def _get_map(self, clz_map):
    all_map = []
    for i in range(1,int(np.max(clz_map)+1)):
        clz_wd = np.zeros(clz_map.shape, dtype=np.uint8)
        where = np.where(clz_map == i)
        clz_wd[where] = 255
        cnts = cv2.findContours(clz_wd, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cnts = cnts[0] if len(cnts) == 2 else cnts[1]
        for cnt in cnts:
            rect = cv2.minAreaRect(cnt)
            box = cv2.boxPoints(rect).reshape((-1,1,2)).astype(np.int32)
            cv2.drawContours(clz_wd,[box],0,255,2)
            cv2.drawContours(clz_wd,[box],0,255,-1)
        all_map.append(clz_wd)
    return all_map

def _filt_map(self, all_map):
    maps = []
    dindx = []
    for i1, m1 in enumerate(all_map):
        for i2, m2 in enumerate(all_map):
            if i1 != i2:
                if np.sum(m2[m1 != 0]) == np.sum(m2):
                    dindx.append(i2)
    for i1, m1 in enumerate(all_map):
        if not i1 in dindx:
            for i2, m2 in enumerate(all_map[i1+1:]):
                an = ((m1 == 0) + (m2 == 0)) == 0
                if np.sum(an) != 0:
                    if np.sum(m1) > np.sum(m2):
                        m2[an] = 0
                    else:
                        m1[an] = 0
            maps.append(m1)
    return np.array(maps)

BoundingBoxの作成

後は、Center Pointの認識結果から、それぞれのクラスターに含まれている点(CenterNetに習って周囲8ピクセルより高い値の点を抽出しています)を元にBoundingBoxを作成します。

プログラムコード

周囲8ピクセルより高い値の点を抽出しているのは、misc/detection.pyの以下の箇所です。

def _conv3_filter(self, img, pos, pos_div):
    points = []
    rects = []
    for y in range(1,img.shape[0]-1):
        for x in range(1,img.shape[1]-1):
            if img[y,x]>self.word_threshold and img[y,x]>img[y-1,x] and img[y,x]>img[y-1,x-1] and img[y,x]>img[y-1,x+1] and img[y,x]>img[y,x-1] and img[y,x]>img[y,x+1] and img[y,x]>img[y+1,x] and img[y,x]>img[y+1,x-1] and img[y,x]>img[y+1,x+1]:
                points.append((x,y))
                w, h = pos[0][y//pos_div,x//pos_div]/2, pos[1][y//pos_div,x//pos_div]/2
                w, h = 1+w*img.shape[1], 1+h*img.shape[0]
                offw, offh = int(np.round(w)), int(np.round(h))
                rects.append((x-offw,y-offh,x+offw,y+offh))
    return points, rects

BoundingBoxの拡張

そして、CenterNet同様、BoundingBoxの拡張を行います。

スクリーンショット 2020-04-11 10.04.10.png

上記のフィルターで取り出されたCenterPointの候補と、ニューラルネットワークが出力するBoundingBoxの大きさからBoundingBoxの候補を作り、その周囲に、大きさ1.1倍のBoxと、縦横1.25倍×1.25分の一の細長いBoxを追加してやります。
これにより、候補領域の数を増やすことで、より正確な認識を目指します。
ソースコードは、make_boundingbox関数が対応します。

def make_boundingbox(self, hm_wd, hm_pos, pos_div=2, min_bound=24, resize_val=1.1, aspect_val=1.25):
     pos, rcts = self._conv3_filter(hm_wd, hm_pos, pos_div)
     min_bound = min_bound // 4
     for p, r in zip(pos, rcts):
         x1, y1, x2, y2 = r
         x1 = min(max(0,x1), hm_wd.shape[1])
         y1 = min(max(0,y1), hm_wd.shape[0])
         x2 = min(max(0,x2), hm_wd.shape[1])
         y2 = min(max(0,y2), hm_wd.shape[0])
         w, h = x2-x1, y2-y1
         if min(w, h) >= min_bound:
             self.boundingboxs.append(BoundingBox(x1, y1, x2, y2))
         w2 = int(np.round((x2-x1) * resize_val))
         h2 = int(np.round((y2-y1) * resize_val))
         if w2 != w and h2 != h and min(w2, h2) >= min_bound:
             xx1 = min(max(0, p[0] - w2//2), hm_wd.shape[1])
             yy1 = min(max(0, p[1] - h2//2), hm_wd.shape[0])
             xx2 = min(max(0, p[0] - w2//2 + w2), hm_wd.shape[1])
             yy2 = min(max(0, p[1] - h2//2 + h2), hm_wd.shape[0])
             self.boundingboxs.append(BoundingBox(xx1, yy1, xx2, yy2))
         w2 = int(np.round((x2-x1) * aspect_val))
         h2 = int(np.round((y2-y1) / aspect_val))
         if w2 != w and h2 != h and min(w2, h2) >= min_bound:
             xx1 = min(max(0, p[0] - w2//2), hm_wd.shape[1])
             yy1 = min(max(0, p[1] - h2//2), hm_wd.shape[0])
             xx2 = min(max(0, p[0] - w2//2 + w2), hm_wd.shape[1])
             yy2 = min(max(0, p[1] - h2//2 + h2), hm_wd.shape[0])
             self.boundingboxs.append(BoundingBox(xx1, yy1, xx2, yy2))
         w2 = int(np.round((x2-x1) / aspect_val))
         h2 = int(np.round((y2-y1) * aspect_val))
         if w2 != w and h2 != h and min(w2, h2) >= min_bound:
             xx1 = min(max(0, p[0] - w2//2), hm_wd.shape[1])
             yy1 = min(max(0, p[1] - h2//2), hm_wd.shape[0])
             xx2 = min(max(0, p[0] - w2//2 + w2), hm_wd.shape[1])
             yy2 = min(max(0, p[1] - h2//2 + h2), hm_wd.shape[0])
             self.boundingboxs.append(BoundingBox(xx1, yy1, xx2, yy2))

ただし、その後のクラス分類ニューラルネットワークの実行回数が増えるので、速度はその分だけ犠牲になっています。

出来上がったBoundingBox

そうしてBoundingBoxを全て作ると、以下のようになります。

作成されたBoundingBox

作成されたBoundingBox

実際は、このOCRプログラムでは、その後の文字種類の識別の時に、さらに「縦に繋がっている文字」「横に繋がっている文字」を特別なクラスとして認識してBoundingBoxを増やしたりしています。この処理は冗長なので解説から省きます。

文字のクラス分類

抽出されたBoundingBoxは、まだ文字の候補領域でしかありません。

この段階では、以下のように、重複して認識されたものが沢山あります。

抽出された文字

抽出された文字

これをクラス分類用のニューラルネットワークに全て入力し、認識結果を得ます。

そして、認識結果から、文字として認識したものの内から、その確率を得ます(出力層のsoftmaxの値)。

最も文字として認識された確率と、Center Pointの認識結果から得られる、そこに文字がある確率とを掛け合わせると、BoundingBoxの確率スコアになります。

NMSを実行

そして、BoundingBoxの確率スコアに基づくNonMaxSuppressionを実行すると、OCRの結果が得られます。

実行結果
実行結果

プログラムコード

NonMaxSuppressionのコードは、misc/nms.pyにあります。
あまり実行効率が良いコードでは無いです。
NonMaxSuppressionについてはもっと高速なアルゴリズムが知られているので、書き換えても良いのですが、ニューラルネットワークの実行に比べれば些細な時間なので放置してあります。

def non_max_suppression(boxes, overlapThresh=0.3):
    if len(boxes) == 0:
        return []

    sorted_box = sorted(boxes, key=lambda x:x.score())[::-1]
    ignore_flg = [False] * len(sorted_box)

    for i in range(len(sorted_box)):
        if not ignore_flg[i]:
            for j in range(i+1,len(sorted_box),1):
                r1 = sorted_box[i]
                r2 = sorted_box[j]
                if r1.x1 <= r2.x2 and r2.x1 <= r1.x2 and r1.y1<= r2.y2 and r2.y1 <= r1.y2:
                    w = max(0, min(r1.x2,r2.x2) - max(r1.x1,r2.x1))
                    h = max(0, min(r1.y2,r2.y2) - max(r1.y1,r2.y1))
                    if w * h > (r2.x2-r2.x1)*(r2.y2-r2.y1)*overlapThresh:
                        ignore_flg[j] = True

    return [sorted_box[i] for i in range(len(sorted_box)) if not ignore_flg[i]]

縦書き、横書きの検出

これまでの内容を図にすると、以下のような感じ。

NMSまで

これまでの内容は、文章領域に含まれている文字同士の関係性については触れずに来ました。
つまり、文字が同じ文章に含まれているとは判定出来ても、それぞれの位置を適切に並び替えて、文章としてやらなければなりません。
ここで難しいのは、日本語には縦書きと横書きの両方がある、と言う点です。
また、同じ文章領域に、複数の行が含まれる場合も考えなければなりません。
例えば、上の例は縦書きの例ですが、下の例は横書きでさらに一つの文章領域に複数の行が含まれている場合の例です。

一領域に複数の行

縦書き・横書きの認識と、領域内に含まれている行の認識には、文章領域中の文字の位置(BoundingBoxをNMSした後の領域)を、さらにX座標Y座標毎に、1次元のNonMaxSuppressionを実行してやります。

縦書き・横書きの検出

すると、領域内の行と列の数が解るので、小さい方を文章の行として認識します。

プログラムコード

1次元のNonMaxSuppressionのコードは、misc/nms.py内にあります。

def _1dim_non_suppression(ranges, overlapThresh):
    if len(ranges) == 0:
        return []

    ignore_flg = [False] * len(ranges)

    for i in range(len(ranges)):
        if not ignore_flg[i]:
            for j in range(i+1,len(ranges),1):
                r1 = ranges[i]
                r2 = ranges[j]
                w = max(0, min(r1[1],r2[1]) - max(r1[0],r2[0]))
                if w > (r2[1]-r2[0])*overlapThresh:
                    ignore_flg[j] = True

    return [ranges[i] for i in range(len(ranges)) if not ignore_flg[i]]

後処理

後は、小さな「っ」の認識だとか、カタカナのカと漢字の力のような、同じ形をした文字を、前後の文字種類から推測してやる処理が行われ、最終的な結果が作成されます。
このアルゴリズムは、単純に辞書を作って、直前直後の文字種が一致する場合は置換を行う、という単純なものです。
辞書はmisc/nihongo.pyに定義されています。

# 変換前、変換後、直前の文字種、直後の文字種
filter_word = \
[ ('り', 'リ', katakana, katakana),
  ('リ', 'り', hiragana, hiragana),
  ('へ', 'ヘ', katakana, katakana),
  ('ヘ', 'へ', hiragana, hiragana),
・・・

小さな「っ」の認識とかはものすごく雑で、その後に続くのが「たちつてと」であれば「つ」は問答無用に「っ」に変換、とかしています。
このあたり、日本語文章の文法規則とか、より良いアルゴリズムをご存じの方いらっしゃればご教示くださいませ。

プログラムコード

後処理のフィルター処理は、ocr_japanease.py内のfilter_block関数が実行します。

def filter_block(sent):
    for i in range(len(sent)):
        for j in range(len(filter_word)):
            if filter_word[j][0] == sent[i]:
                bef = filter_word[j][2] is None or (i>0 and sent[i-1] in filter_word[j][2])
                aft = filter_word[j][3] is None or (i<len(sent)-1 and sent[i+1] in filter_word[j][3])
                if bef and aft:
                    sent[i] = filter_word[j][1]

後処理については、このようにごく単純なもので済ませています。
NLPの手法を使ったより良い手法を使いたいところではありますが、色々と複雑になると大変なので、そこまで作り込むことはしませんでした。

まとめ

まとめ。
日本語OCRのプログラムについて、ソースコードも含めて割と詳しめに解説してきました。
それでも1記事で収まっているように、このOCRプログラムは、基本的に凄く単純な作りをしています。
例えば、画像の前処理などは、単純なコントラスト調整やノイズリダクションすら一切入っていないですし、後処理についても性能向上の余地は多分にあります。
制作者の希望としては、このOCRのプログラムを素材として、改良版のプログラムを作ってくれれば嬉しいです。

それでは。

90
92
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
90
92