
Rewriting the image captioning model "CATR" so that it embeds its output caption into the input image and saves the result to a file

Posted at 2021-08-05

In the __previous article__, I downloaded the CATR implementation from GitHub and tried running it.

###( Previous article )

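This time, I modified __predict.py__, the inference script from the repository I cloned with git clone, so that it draws the generated caption onto the input image and saves the result to a file.

If the caption is longer than 40 characters, it is wrapped onto multiple lines and the wrapped lines are drawn onto the image. Only the first three lines are drawn; anything from the fourth line onward is discarded.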
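The wrapping rule above can be illustrated with Python's standard textwrap module, which the script below also uses. This is a minimal sketch: wrap_caption is a hypothetical helper name, not part of CATR, and it only shows the "wrap at 40 characters, keep at most 3 lines" rule in isolation.

```python
import textwrap
from typing import List


def wrap_caption(caption: str, width: int = 40, max_lines: int = 3) -> List[str]:
    """Wrap a caption at `width` characters and keep at most `max_lines` lines."""
    # textwrap.wrap breaks at whitespace and returns a list of lines,
    # none longer than `width` characters.
    lines = textwrap.wrap(caption, width)
    # Lines beyond `max_lines` are dropped, mirroring the rule described above.
    return lines[:max_lines]


print(wrap_caption("A man sitting on a beach with a surfboard."))
# → ['A man sitting on a beach with a', 'surfboard.']
```

This 42-character caption comes back as two lines, which matches the 2 that the script prints after the caption in the terminal output below.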

####( Input file )

test008.jpeg

####( Output file )

An image file with the caption embedded is written directly under the image_files directory.

test008.jpe_解析結果.jpeg
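The script shown later builds the output name with image_path[0:-1], which removes only the last character of the path; that is why test008.jpeg becomes test008.jpe_解析結果.jpeg (解析結果 means "analysis result"). If you would rather strip the whole extension, os.path.splitext is one option. A minimal sketch; make_output_path is a hypothetical helper, and the _解析結果 suffix is kept from the article.

```python
import os


def make_output_path(image_path: str, suffix: str = "_解析結果", ext: str = ".jpeg") -> str:
    """Build an output path next to the input image, replacing its extension."""
    # os.path.splitext splits off the last extension,
    # e.g. './image_files/test008.jpeg' -> ('./image_files/test008', '.jpeg').
    root, _ = os.path.splitext(image_path)
    return f"{root}{suffix}{ext}"


print(make_output_path("./image_files/test008.jpeg"))  # → ./image_files/test008_解析結果.jpeg
print(make_output_path("./image_files/test005.jpg"))   # → ./image_files/test005_解析結果.jpeg
```

Unlike slicing off a fixed number of characters, this works the same for .jpg, .png and .jpeg inputs.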

####( Terminal output )

Terminal
electron@diynoMacBook-Pro catr % python3 predict_output_file_save.py --path ./image_files/test008.jpeg --v v3
Using cache found in /Users/ocean/.cache/torch/hub/saahiluppal_catr_master
A man sitting on a beach with a surfboard.
2
electron@diynoMacBook-Pro catr % 

###Let's try another image file.

####( Input file )

File name: test005.jpg

test005.jpg

####( Output file )

An image file with the caption embedded is written directly under the image_files directory.
Output file name: test005.jp_解析結果.jpeg

test005.jp_解析結果.jpeg

Terminal
electron@diynoMacBook-Pro catr %  ls image_files 
dog.jpg				test002.jpg			test005.jpg			test010.jpg
special.jp_解析結果.jpeg	test003.jpg			test007.jpeg
special.jpg			test004.jpg			test008.jpe_解析結果.jpeg
test001.png			test005.jp_解析結果.jpeg	test008.jpeg
electron@diynoMacBook-Pro catr %  
electron@diynoMacBook-Pro catr %  open ./image_files/test005.jp_解析結果.jpeg 
electron@diynoMacBook-Pro catr %

####( Terminal output )

Terminal
electron@diynoMacBook-Pro catr % python3 predict_output_file_save.py --path ./image_files/test005.jpg --v v3 
Using cache found in /Users/ocean/.cache/torch/hub/saahiluppal_catr_master
A man holding a bottle of beer and a man in a yellow shirt.
2
electron@diynoMacBook-Pro catr % 

###Let's try one more image file.
####( Input file )

File name: special.jpg

special.jpg

####( Output file )

An image file with the caption embedded is written directly under the image_files directory.
Output file name: special.jp_解析結果.jpeg

special.jp_解析結果.jpeg

####( Terminal output )

Terminal
electron@diynoMacBook-Pro catr %  python3 predict_output_file_save.py --path ./image_files/special.jpg --v v3
Using cache found in /Users/ocean/.cache/torch/hub/saahiluppal_catr_master
A picture of a bunch of books and a teddy bear.
2
electron@diynoMacBook-Pro catr % 

#The script I ran is shown below.

  • Script file name: predict_output_file_save.py
predict_output_file_save.py
import torch
from transformers import BertTokenizer
from PIL import Image
import argparse
import cv2
import matplotlib.pyplot as plt
from models import caption
from datasets import coco, utils
from configuration import Config
import os
import textwrap

parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--path', type=str, help='path to image', required=True)
parser.add_argument('--v', type=str, help='version', default='v3')
parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None)
args = parser.parse_args()
image_path = args.path
version = args.v
checkpoint_path = args.checkpoint

config = Config()

if version in ('v1', 'v2', 'v3'):
    # Load the pretrained CATR model of the requested version via torch.hub
    model = torch.hub.load('saahiluppal/catr', version, pretrained=True)
else:
    print("Checking for checkpoint.")
    if checkpoint_path is None:
        raise NotImplementedError('No model to choose from!')
    else:
        if not os.path.exists(checkpoint_path):
            raise NotImplementedError('Give valid checkpoint path')
        print("Found checkpoint! Loading!")
        model, _ = caption.build_model(config)
        print("Loading Checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Token ids of [CLS] (101) and [SEP] (102) in bert-base-uncased
start_token = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)
end_token = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)

image = Image.open(image_path)
image = coco.val_transform(image)
image = image.unsqueeze(0)


def create_caption_and_mask(start_token, max_length):
    caption_template = torch.zeros((1, max_length), dtype=torch.long)
    mask_template = torch.ones((1, max_length), dtype=torch.bool)

    caption_template[:, 0] = start_token
    mask_template[:, 0] = False

    return caption_template, mask_template


caption, cap_mask = create_caption_and_mask(
    start_token, config.max_position_embeddings)


@torch.no_grad()
def evaluate():
    model.eval()
    for i in range(config.max_position_embeddings - 1):
        predictions = model(image, caption, cap_mask)
        predictions = predictions[:, i, :]
        predicted_id = torch.argmax(predictions, axis=-1)

        if predicted_id[0] == end_token:  # [SEP] (id 102) marks the end of the caption
            return caption

        caption[:, i+1] = predicted_id[0]
        cap_mask[:, i+1] = False

    return caption


output = evaluate()
result = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
#result = tokenizer.decode(output[0], skip_special_tokens=True)
final_caption_text = result.capitalize()
#print(result.capitalize())
print(final_caption_text)
#print(result)  prints the same text, but without the capitalization

# Wrapping and truncating strings in Python (textwrap):
# https://note.nkmk.me/python-textwrap-wrap-fill-shorten/
# Wrap the generated caption at 40 characters; textwrap.wrap returns a list of lines.
caption_text_wrap_list = textwrap.wrap(final_caption_text, 40)
# Print the number of wrapped lines (the "2" in the terminal output above)
print(len(caption_text_wrap_list))
# Draw the generated caption onto the input image
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError('cv2.imread could not read {0}'.format(image_path))

# https://qiita.com/Daiki_P/items/91acd5cc208fedd16ee9
img_RGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # convert BGR to RGB (used only by the commented-out matplotlib preview at the end)

# Draw up to 3 wrapped lines at y = 50, 90, 130 (40 px apart) in red ((0, 0, 255) in BGR);
# lines from the 4th onward are discarded.
for i, line in enumerate(caption_text_wrap_list[:3]):
    cv2.putText(img, line, (0, 50 + 40 * i), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)

#output_file_name = "result.jpeg"
# Note: image_path[0:-1] removes only the last character of the path, not the whole
# .jpg/.png/.jpeg extension, so test008.jpeg becomes test008.jpe_解析結果.jpeg
# (解析結果 = "analysis result").
output_file_name = "{0}_解析結果.jpeg".format(image_path[0:-1])
cv2.imwrite(output_file_name, img)

#cv2.putText(img_RGB, str(final_caption_text), (0, 100), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
#plt.imshow(img_RGB)
#plt.show()
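The evaluate() function above performs greedy decoding: at each position it takes the highest-scoring token (argmax), writes it into the caption, and stops as soon as the [SEP] token is predicted. The framework-free sketch below shows the same control flow; greedy_decode, toy_scores and the token ids 1, 2, 7, 8, 9 are hypothetical, and the toy scorer stands in for the CATR model, so no torch or image is involved.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    next_token_scores: Callable[[List[int]], Sequence[float]],
    start_token: int,
    end_token: int,
    max_length: int,
) -> List[int]:
    """Append the highest-scoring token at each step until end_token or max_length."""
    tokens = [start_token]
    # Like evaluate(), predict at most max_length - 1 tokens after the start token.
    for _ in range(max_length - 1):
        scores = next_token_scores(tokens)
        # argmax over the vocabulary (Python's max keeps the first index on ties)
        predicted_id = max(range(len(scores)), key=lambda t: scores[t])
        if predicted_id == end_token:
            break  # stop without appending the end token
        tokens.append(predicted_id)
    return tokens


# Hypothetical toy scorer over a 10-token vocabulary: it emits 7, 8, 9 and then
# the end token 2, regardless of any image.
SCRIPT = [7, 8, 9, 2]
VOCAB_SIZE = 10


def toy_scores(tokens: List[int]) -> List[float]:
    step = len(tokens) - 1  # number of tokens generated so far
    target = SCRIPT[min(step, len(SCRIPT) - 1)]
    return [1.0 if t == target else 0.0 for t in range(VOCAB_SIZE)]


print(greedy_decode(toy_scores, start_token=1, end_token=2, max_length=128))
# → [1, 7, 8, 9]
```

In the actual script, the scores come from model(image, caption, cap_mask)[:, i, :], and the preallocated caption and cap_mask tensors are filled in place instead of growing a Python list.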

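###( Caveats )

Running this script requires an internet connection.
Internally, loading the BERT tokenizer with BertTokenizer.from_pretrained('bert-base-uncased') automatically sends a request to a Hugging Face URL (huggingface.co).
On a machine without internet access, the script fails with the following error.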

Terminal
electron@diynoMacBook-Pro catr % # disconnect from the internet
zsh: command not found: #
electron@diynoMacBook-Pro catr % python3 predict_output_file_save.py --path ./image_files/test005.jpg --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Traceback (most recent call last):
  File "/usr/local/lib/python3.9/site-packages/urllib3/connection.py", line 169, in _new_conn
    conn = connection.create_connection(
  File "/usr/local/lib/python3.9/site-packages/urllib3/util/connection.py", line 73, in create_connection
    for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
  File "/usr/local/Cellar/python@3.9/3.9.6/Frameworks/Python.framework/Versions/3.9/lib/python3.9/socket.py", line 953, in getaddrinfo
    for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
socket.gaierror: [Errno 8] nodename nor servname provided, or not known

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/site-packages/urllib3/connectionpool.py", line 699, in urlopen
    httplib_response = self._make_request(
  File "/usr/local/lib/python3.9/site-packages/urllib3/connectionpool.py", line 382, in _make_request
    self._validate_conn(conn)
  File "/usr/local/lib/python3.9/site-packages/urllib3/connectionpool.py", line 1010, in _validate_conn
    conn.connect()
  File "/usr/local/lib/python3.9/site-packages/urllib3/connection.py", line 353, in connect
    conn = self._new_conn()
  File "/usr/local/lib/python3.9/site-packages/urllib3/connection.py", line 181, in _new_conn
    raise NewConnectionError(
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x1665cc7c0>: Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/site-packages/requests/adapters.py", line 439, in send
    resp = conn.urlopen(
  File "/usr/local/lib/python3.9/site-packages/urllib3/connectionpool.py", line 755, in urlopen
    retries = retries.increment(
  File "/usr/local/lib/python3.9/site-packages/urllib3/util/retry.py", line 573, in increment
    raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/bert-base-uncased (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x1665cc7c0>: Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/electron/Desktop/catr/predict_output_file_save.py", line 42, in <module>
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  File "/usr/local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 1647, in from_pretrained
    fast_tokenizer_file = get_fast_tokenizer_file(
  File "/usr/local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 3408, in get_fast_tokenizer_file
    all_files = get_list_of_files(path_or_repo, revision=revision, use_auth_token=use_auth_token)
  File "/usr/local/lib/python3.9/site-packages/transformers/file_utils.py", line 1686, in get_list_of_files
    model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info(
  File "/usr/local/lib/python3.9/site-packages/huggingface_hub/hf_api.py", line 247, in model_info
    r = requests.get(path, headers=headers)
  File "/usr/local/lib/python3.9/site-packages/requests/api.py", line 76, in get
    return request('get', url, params=params, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/requests/api.py", line 61, in request
    return session.request(method=method, url=url, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/requests/sessions.py", line 542, in request
    resp = self.send(prep, **send_kwargs)
  File "/usr/local/lib/python3.9/site-packages/requests/sessions.py", line 655, in send
    r = adapter.send(request, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/requests/adapters.py", line 516, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/bert-base-uncased (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x1665cc7c0>: Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known'))
electron@diynoMacBook-Pro catr % 
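The traceback shows that the CATR model itself was loaded from the torch.hub cache and that the failure comes from the tokenizer request to huggingface.co. One possible workaround, sketched here under assumptions and not tested against the transformers version used in this article: save the tokenizer to a local directory once while online, then load it from disk with local_files_only=True. LOCAL_TOKENIZER_DIR, tokenizer_source and load_tokenizer are hypothetical names, and whether local_files_only=True avoids every network call may depend on the transformers version.

```python
import os

# Hypothetical directory where the tokenizer was saved once while online, e.g.:
#   BertTokenizer.from_pretrained('bert-base-uncased').save_pretrained(LOCAL_TOKENIZER_DIR)
LOCAL_TOKENIZER_DIR = "./bert-base-uncased-local"


def tokenizer_source(local_dir: str = LOCAL_TOKENIZER_DIR) -> str:
    """Return local_dir if it holds a saved BERT vocabulary, else the Hub model id."""
    if os.path.isfile(os.path.join(local_dir, "vocab.txt")):
        return local_dir
    return "bert-base-uncased"


def load_tokenizer():
    # Imported inside the function so tokenizer_source() works without transformers.
    from transformers import BertTokenizer

    # local_files_only=True asks transformers to use only local or cached files
    # instead of contacting huggingface.co (assumption: behavior may vary by version).
    return BertTokenizer.from_pretrained(tokenizer_source(), local_files_only=True)
```

In predict_output_file_save.py, the line tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') would then become tokenizer = load_tokenizer(). Transformers also documents an offline mode enabled by setting the environment variable TRANSFORMERS_OFFLINE=1.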