0. はじめに
過去にOCR記事を2本書きました。
これが未だに閲覧数が多いので、今年のアドカレネタにもOCRを持ってきました。
結論から言うと今回はOCRというカテゴリではないのでタイトルにも?を入れましたが、まぁ細かいことは気にしない気にしない...(以後も簡単のためOCRと言ってます、ご了承ください)
ということで今回はDonutを紹介します。
- 動作環境
- OS : Windows10 pro
- GPU : GeForce RTX 2060(6GB)
- python: 3.10.11
- transformers: 4.25.1
- Pytorch: 1.12.1 (+cu116)
- protobuf
- sentencepiece
- pillow
- jupyter notebook
1. Donutに関して
「Donut」とは、「Document understanding transformer」の略称で、文書理解のための新しい方法です。
LINEでおなじみ韓国のNaver社が開発したMIT licenseのOSSで、従来のOCR手法のように「画像をテキストに変換してから解析」せずに「画像を直接解析」します。
つまり文字を1文字1文字認識するのではなく、画像が何を表すか?を理解してからそれを基にテキストを生成するという画期的なアプローチをとっています。
細かい理論は以下論文に書いてあるので、気になる方は読んでみてください。
・以下が元論文
・以下がgithub
Donutの仕組み:
・エンコーダでは画像をSwin Transformer
を使用して特徴抽出し、ベクトル化
・デコーダーでは事前に色々な言語で学習されたBART(Bidirectional and Auto-Regressive Transformers)
を使うことでテキストを生成
・つまり、あくまで画像の特徴から文字を生成する仕組みなのでOCRではない
・このアプローチにより、速度と精度の両方で優れたパフォーマンスを示せる
2. Donutを検証してみる
では実際にDonutを使用するサンプルコードを記載する。
なお、今回は簡単の為事前学習済モデルをそのまま活用することにする。
ファインチューニングすれば精度があがる可能性は高いが、データ集め等労力もかかるので今回は言及しない。
2-1. ライブラリの導入
私は以下をpipで入れました(実際にはpython環境管理は便利なryeを使用してるのでpip不使用)
pip install ipykernel
pip install torch ※torchに関しては皆さんの環境に合わせて適切なコマンド入れてください
pip install transformers
pip install datasets
pip install pillow
pip install sentencepiece
pip install protobuf
2-2. 学習で使用された日本語データを見てみる
今回使用する「事前学習済モデル」では4か国語で学習されているらしいです。
(英語/中国語/日本語/韓国語それぞれ500データずつ)
そのうち日本語ではどんなもので学習されたか見てみます。
ただし、全部をloadすると無茶苦茶時間がかかるので、2枚目の画像を見てみることにする
※なんで2枚目かというと、1枚目はよくわからんやつなので(2枚目もイミフですが1枚目よりマシ?)
from datasets import load_dataset
# 日本語のデータセットを「streaming=True」でloadする(synthdog-ja)
image_dataset = load_dataset('naver-clova-ix/synthdog-ja', streaming=True)
# 2枚目の画像を取得
counter = 0
for example in image_dataset["train"]:
counter += 1
if counter == 2:
# ここで2枚目の画像を処理
break
# 変数化(後でも使う)
image = image_example["image"]
# notebook上に表示してみる
display(image)
2-3.この画像でOCR精度を見てみる
学習に使っているので正解するとは思いますが、まぁ見てみましょう
使用するモデルは以下huggingfaceから「donut-base」というものを使用します
なお、コードの中で「task」というものが出てきますが以下の種類があるっぽいです。
データセットや解きたいタスクで選ぶみたいですね。
主要なTask:
・文章抽出 : CORD(s_cord-v2) / SynthDoG(s_synthdog)
・画像分類 : RVL-CDIP(s_rvlcdip)
・QA(画像に質問し、画像から回答を得る) : DocVQA(s_docvqa)
import re
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
# モデルを指定(donut-base)
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
# GPU適応
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# モデルに入力用の画像を準備する(datasetsで変数化したimageを使う)
pixel_values = processor(image, return_tensors="pt").pixel_values
# タスクを指定する(レシート読み取りならCORDらしいので今回はSynthDoG)
task_name = "synthdog"
task_prompt = f"<s_{task_name}>"
# デコーダーinput
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
# output
outputs = model.generate(
pixel_values.to(device),
decoder_input_ids=decoder_input_ids.to(device),
max_length=model.decoder.config.max_position_embeddings,
pad_token_id=processor.tokenizer.pad_token_id,
eos_token_id=processor.tokenizer.eos_token_id,
use_cache=True,
bad_words_ids=[[processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
# 文字を生成(デコード)
sequence = processor.batch_decode(outputs.sequences)[0]
# 後ろにくっついてるpaddingトークン</s>を消す
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
# 先頭にくっついてるtaskトークン<s_synthdog>を消す
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
# わかりやすく半角スペースで改行して出力
print(processor.token2json(sequence)["text_sequence"].replace(" ", "\n"))
どうやら合っていそうである。
2-4. 自分の画像でOCR検証
さて、いつものように青空文庫からスクショしてきた「宮沢賢治」の画像でOCRをしてみることにする。どれくらいの精度が出せるだろうか?
import re
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
from PIL import Image
# モデル
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# 宮沢画像読み込み
image = Image.open('miyazawa.png').convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values
# タスク(synthdog)
task_name = "synthdog"
task_prompt = f"<s_{task_name}>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
# output
outputs = model.generate(
pixel_values.to(device),
decoder_input_ids=decoder_input_ids.to(device),
max_length=model.decoder.config.max_position_embeddings,
pad_token_id=processor.tokenizer.pad_token_id,
eos_token_id=processor.tokenizer.eos_token_id,
use_cache=True,
bad_words_ids=[[processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
print(processor.token2json(sequence)["text_sequence"].replace(" ", "\n"))
微妙そうに見えて9割くらいは正解していそうかな??
ただそもそも学習のトークンに含まれてないものは生成できないですので、以下に含まれないものは生成できませんね(宮沢賢治の「賢」は含まれてませんでした)
https://huggingface.co/naver-clova-ix/donut-base/raw/main/tokenizer.json
3. おわりに
純粋なOCRではないかもしれませんが、標準的な文章ならまぁまぁな精度が出そうな気がしました。
GPTsでもOCRはできますが、一度は試してみてはいかがでしょうか?
なお、Donutを使って日本語で調整されたモデルは現時点で以下1個しかなかったです。
今後増えるかは事例次第ですかね??
それでは今回はここまで
参考