__前回の記事__では、CATRモデルの実装コードを、githubから落として、動作検証を行ってみました。
###( 前回の記事 )
今回は、git cloneした推論工程の実行ファイル __predict.py__を改修して、出力されたキャプション文を入力画像に文字列埋込して、結果をファイル保存するように書き換えてみました。
なお、キャプション文が40文字を超える場合は、改行して、改行した結果を入力画像に埋め込むようにしました。4行目以降は、埋め込まずに捨てることにしました。
####( 入力ファイル )
####( 出力ファイル )
image_filesディレクトリの直下に、キャプション文を埋め込んだ画像ファイルが出力されます。
####( 実行画面 )
electron@diynoMacBook-Pro catr % python3 predict_output_file_save.py --path ./image_files/test008.jpeg --v v3
Using cache found in /Users/ocean/.cache/torch/hub/saahiluppal_catr_master
A man sitting on a beach with a surfboard.
2
electron@diynoMacBook-Pro catr %
###他の画像ファイルを試してみます。
####( 入力ファイル )
ファイル名: test005.jpg
####( 出力ファイル )
image_filesディレクトリの直下に、キャプション文を埋め込んだ画像ファイルが出力されます。
出力ファイル名: test005.jp_解析結果.jpeg
electron@diynoMacBook-Pro catr % ls image_files
dog.jpg test002.jpg test005.jpg test010.jpg
special.jp_解析結果.jpeg test003.jpg test007.jpeg
special.jpg test004.jpg test008.jpe_解析結果.jpeg
test001.png test005.jp_解析結果.jpeg test008.jpeg
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % open ./image_files/test005.jp_解析結果.jpeg
electron@diynoMacBook-Pro catr %
####( 実行画面 )
electron@diynoMacBook-Pro catr % python3 predict_output_file_save.py --path ./image_files/test005.jpg --v v3
Using cache found in /Users/ocean/.cache/torch/hub/saahiluppal_catr_master
A man holding a bottle of beer and a man in a yellow shirt.
2
electron@diynoMacBook-Pro catr %
###もう1つ別の画像ファイルでも試してみます。
####( 入力ファイル )
ファイル名: special.jpg
####( 出力ファイル )
image_filesディレクトリの直下に、キャプション文を埋め込んだ画像ファイルが出力されます。
出力ファイル名: special.jp_解析結果.jpeg
####( 実行画面 )
electron@diynoMacBook-Pro catr % python3 predict_output_file_save.py --path ./image_files/special.jpg --v v3
Using cache found in /Users/ocean/.cache/torch/hub/saahiluppal_catr_master
A picture of a bunch of books and a teddy bear.
2
electron@diynoMacBook-Pro catr %
#実行したのは、次のスクリプト・ファイルです。
- スクリプト・ファイル名: predict_output_file_save.py
import torch
from transformers import BertTokenizer
from PIL import Image
import argparse
import cv2
import matplotlib.pyplot as plt
from models import caption
from datasets import coco, utils
from configuration import Config
import os
import textwrap
parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--path', type=str, help='path to image', required=True)
parser.add_argument('--v', type=str, help='version', default='v3')
parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None)
args = parser.parse_args()
image_path = args.path
version = args.v
checkpoint_path = args.checkpoint
config = Config()
if version == 'v1':
model = torch.hub.load('saahiluppal/catr', 'v1', pretrained=True)
elif version == 'v2':
model = torch.hub.load('saahiluppal/catr', 'v2', pretrained=True)
elif version == 'v3':
model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)
else:
print("Checking for checkpoint.")
if checkpoint_path is None:
raise NotImplementedError('No model to chose from!')
else:
if not os.path.exists(checkpoint_path):
raise NotImplementedError('Give valid checkpoint path')
print("Found checkpoint! Loading!")
model,_ = caption.build_model(config)
print("Loading Checkpoint...")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
start_token = tokenizer.convert_tokens_to_ids(tokenizer._cls_token)
end_token = tokenizer.convert_tokens_to_ids(tokenizer._sep_token)
image = Image.open(image_path)
image = coco.val_transform(image)
image = image.unsqueeze(0)
def create_caption_and_mask(start_token, max_length):
caption_template = torch.zeros((1, max_length), dtype=torch.long)
mask_template = torch.ones((1, max_length), dtype=torch.bool)
caption_template[:, 0] = start_token
mask_template[:, 0] = False
return caption_template, mask_template
caption, cap_mask = create_caption_and_mask(
start_token, config.max_position_embeddings)
@torch.no_grad()
def evaluate():
model.eval()
for i in range(config.max_position_embeddings - 1):
predictions = model(image, caption, cap_mask)
predictions = predictions[:, i, :]
predicted_id = torch.argmax(predictions, axis=-1)
if predicted_id[0] == 102:
return caption
caption[:, i+1] = predicted_id[0]
cap_mask[:, i+1] = False
return caption
output = evaluate()
result = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
#result = tokenizer.decode(output[0], skip_special_tokens=True)
final_caption_text = str(result.capitalize())
#print(result.capitalize())
print(final_caption_text)
#print(result) 上と同じ文字列が出力される
# Pythonで文字列を折り返し・切り詰めして整形する
# https://note.nkmk.me/python-textwrap-wrap-fill-shorten/
# 生成されたキャプション文を40文字で改行する。-> List型が返却される
caption_text_wrap_list = textwrap.wrap(str(final_caption_text), 40)
print(len(caption_text_wrap_list))
# 入力した画像ファイルに、生成したキャプション文の文字列を埋込む
img = cv2.imread(image_path)
#https://qiita.com/Daiki_P/items/91acd5cc208fedd16ee9
img_RGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #BGRをRGBに変換.
if len(caption_text_wrap_list) == 1:
cv2.putText(img, str(final_caption_text), (0, 50), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
elif len(caption_text_wrap_list) ==2 :
cv2.putText(img, str(caption_text_wrap_list[0]), (0, 50), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
cv2.putText(img, str(caption_text_wrap_list[1]), (0, 90), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
elif len(caption_text_wrap_list) >= 3:
cv2.putText(img, str(caption_text_wrap_list[0]), (0, 50), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
cv2.putText(img, str(caption_text_wrap_list[1]), (0, 90), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
cv2.putText(img, str(caption_text_wrap_list[2]), (0, 130), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
#output_file_name = "result.jpeg"
# ファイルパスの最期の.jpg, .png, .jpegを削除
output_file_name = "{0}_解析結果.jpeg".format(image_path[0:-1])
cv2.imwrite(output_file_name,img)
#cv2.putText(img_RGB, str(final_caption_text), (0, 100), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 1, cv2.LINE_AA)
#plt.imshow(img_RGB)
#plt.show()
###( 注意点 )
このスクリプト・ファイルを実行する際は、PC端末がインターネットにつながっている必要があります。
内部処理の中で、BERTを使う場面でHuggingfaceのURLに自動アクセスする処理が走るからです。
インターネットにつながっていない端末で実行すると、以下のエラーが吐かれます。
Terminalelectron@diynoMacBook-Pro catr % # インターネット接尾を切る zsh: command not found: # electron@diynoMacBook-Pro catr % python3 predict_output_file_save.py --path >./image_files/test005.jpg --v v3 Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master Traceback (most recent call last): File "/usr/local/lib/python3.9/site-packages/urllib3/connection.py", line 169, in >_new_conn conn = connection.create_connection( File "/usr/local/lib/python3.9/site-packages/urllib3/util/connection.py", line 73, in <create_connection for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): File >"/usr/local/Cellar/python@3.9/3.9.6/Frameworks/Python.framework/Versions/3.9/lib/python3>.9/socket.py", line 953, in getaddrinfo for res in _socket.getaddrinfo(host, port, family, type, proto, flags): socket.gaierror: [Errno 8] nodename nor servname provided, or not known During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/usr/local/lib/python3.9/site-packages/urllib3/connectionpool.py", line 699, in >urlopen httplib_response = self._make_request( File "/usr/local/lib/python3.9/site-packages/urllib3/connectionpool.py", line 382, in >_make_request self._validate_conn(conn) File "/usr/local/lib/python3.9/site-packages/urllib3/connectionpool.py", line 1010, in >_validate_conn conn.connect() File "/usr/local/lib/python3.9/site-packages/urllib3/connection.py", line 353, in >connect conn = self._new_conn() File "/usr/local/lib/python3.9/site-packages/urllib3/connection.py", line 181, in >_new_conn raise NewConnectionError( urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at >0x1665cc7c0>: Failed to establish a new connection: [Errno 8] nodename nor servname >provided, or not known During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/usr/local/lib/python3.9/site-packages/requests/adapters.py", line 439, in send resp = conn.urlopen( File "/usr/local/lib/python3.9/site-packages/urllib3/connectionpool.py", line 755, in >urlopen retries = retries.increment( File "/usr/local/lib/python3.9/site-packages/urllib3/util/retry.py", line 573, in >increment raise MaxRetryError(_pool, url, error or ResponseError(cause)) urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='huggingface.co', port=443): >Max retries exceeded with url: /api/models/bert-base-uncased (Caused by >NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x1665cc7c0>: Failed >to establish a new connection: [Errno 8] nodename nor servname provided, or not known')) During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/Users/electron/Desktop/catr/predict_output_file_save.py", line 42, in <module> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') File "/usr/local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", >line 1647, in from_pretrained fast_tokenizer_file = get_fast_tokenizer_file( File "/usr/local/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", >line 3408, in get_fast_tokenizer_file all_files = get_list_of_files(path_or_repo, revision=revision, >use_auth_token=use_auth_token) File "/usr/local/lib/python3.9/site-packages/transformers/file_utils.py", line 1686, >in get_list_of_files model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info( File "/usr/local/lib/python3.9/site-packages/huggingface_hub/hf_api.py", line 247, in >model_info r = requests.get(path, headers=headers) File "/usr/local/lib/python3.9/site-packages/requests/api.py", line 76, in get return request('get', url, params=params, **kwargs) File "/usr/local/lib/python3.9/site-packages/requests/api.py", line 61, in request return session.request(method=method, url=url, **kwargs) File "/usr/local/lib/python3.9/site-packages/requests/sessions.py", line 542, in >request resp = self.send(prep, **send_kwargs) File "/usr/local/lib/python3.9/site-packages/requests/sessions.py", line 655, in send r = adapter.send(request, **kwargs) File "/usr/local/lib/python3.9/site-packages/requests/adapters.py", line 516, in send raise ConnectionError(e, request=request) requests.exceptions.ConnectionError: HTTPSConnectionPool(host='huggingface.co', >port=443): Max retries exceeded with url: /api/models/bert-base-uncased (Caused by >NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x1665cc7c0>: Failed >to establish a new connection: [Errno 8] nodename nor servname provided, or not known')) electron@diynoMacBook-Pro catr %