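#Code I ran
There is a PyTorch implementation of __CAption TRansformer__, so I git-cloned it and ran it.
#Trying it out
The repository provides __pre-trained models__. The weights are not stored in the clone itself; they are downloaded the first time a model is used.
This time I feed an image file to a pre-trained model and have it output a descriptive sentence (a caption) for that image.
Inference with a pre-trained model is done with __predict.py__.
The image to feed in is specified by passing its path as a command-line argument.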
electron@diynoMacBook-Pro catr % cat predict.py | grep image_path
image_path = args.path
image = Image.open(image_path)
electron@diynoMacBook-Pro catr %
####Fetching an image file from Google Image Search
electron@diynoMacBook-Pro catr % mkdir image_files
electron@diynoMacBook-Pro catr % cd image_files
electron@diynoMacBook-Pro image_files % ls
test001.png
electron@diynoMacBook-Pro image_files %
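###Passing the downloaded image file and generating a caption
- The pre-trained model can be chosen from v1, v2, and v3; I start with v2.
On the first run, the pre-trained weight file for the v2 model starts downloading.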
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test001.png --v v2
Downloading: "https://github.com/saahiluppal/catr/archive/master.zip" to /Users/electron/.cache/torch/hub/master.zip
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/electron/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170M/170M [00:13<00:00, 13.6MB/s]
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.2/weight389123791.pth" to /Users/electron/.cache/torch/hub/checkpoints/weight389123791.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:26<00:00, 13.0MB/s]
Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 232k/232k [00:00<00:00, 792kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.0/28.0 [00:00<00:00, 7.09kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 466k/466k [00:00<00:00, 1.06MB/s]
Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 570/570 [00:00<00:00, 144kB/s]
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A man riding on the back of a brown horse.
electron@diynoMacBook-Pro catr %
##[Result] A man riding on the back of a brown horse.
The sentence __A man riding on the back of a brown horse.__ was printed to standard output in the Terminal.
Note that the first time v2 is selected, the pre-trained weight file for the v2 model is downloaded __via torch.hub__.
predict.py
if version == 'v1':
    model = torch.hub.load('saahiluppal/catr', 'v1', pretrained=True)
elif version == 'v2':
    model = torch.hub.load('saahiluppal/catr', 'v2', pretrained=True)
elif version == 'v3':
    model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)
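The download lines in the log above show the weights landing under ~/.cache/torch/hub/checkpoints. Below is a minimal stdlib-only sketch of how that default location is derived, assuming the documented PyTorch convention ($TORCH_HOME/hub, with TORCH_HOME falling back to $XDG_CACHE_HOME/torch and then ~/.cache/torch). The helper names default_hub_dir and cached_weight_path are hypothetical and are not part of catr or PyTorch.

```python
import os
from pathlib import Path


def default_hub_dir(env=None):
    """Return the directory torch.hub is expected to use by default."""
    env = os.environ if env is None else env
    if "TORCH_HOME" in env:
        torch_home = Path(env["TORCH_HOME"])
    else:
        # Fall back to $XDG_CACHE_HOME/torch, then to ~/.cache/torch
        cache_home = Path(env.get("XDG_CACHE_HOME", str(Path.home() / ".cache")))
        torch_home = cache_home / "torch"
    return torch_home / "hub"


def cached_weight_path(filename, env=None):
    """Path where a downloaded weight file (e.g. weight389123791.pth) would be cached."""
    return default_hub_dir(env) / "checkpoints" / filename
```

With torch installed, torch.hub.get_dir() reports the directory actually in use, which is the safer check if the location has been overridden.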
####Trying another image
electron@diynoMacBook-Pro catr % ls ./image_files
test001.png test002.jpg
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v2
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A group of men laying on top of a bed.
electron@diynoMacBook-Pro catr %
##[Result] A group of men laying on top of a bed.
This time the sentence __A group of men laying on top of a bed.__ was returned.
####Trying the pre-trained v3 model
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth" to /Users/electron/.cache/torch/hub/checkpoints/weight493084032.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:24<00:00, 13.9MB/s]
A group of men laying on top of a beach.
electron@diynoMacBook-Pro catr %
__Almost the same sentence__ was returned; only the last word changed, from "bed" (v2) to "beach" (v3).
What about v1?
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v1
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.1/weights_9348032.pth" to /Users/electron/.cache/torch/hub/checkpoints/weights_9348032.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:23<00:00, 14.4MB/s]
A man in a wetsuit laying on a beach.
electron@diynoMacBook-Pro catr %
##[Result] A man in a wetsuit laying on a beach.
This time a __different sentence__ was returned.
The returned sentence is __A man in a wetsuit laying on a beach.__
There appears to be no v4.
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v4
Checking for checkpoint.
Traceback (most recent call last):
File "/Users/electron/Desktop/catr/predict.py", line 32, in <module>
raise NotImplementedError('No model to chose from!')
NotImplementedError: No model to chose from!
electron@diynoMacBook-Pro catr %
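The traceback follows from the branch structure of predict.py: any --v value other than v1, v2, or v3 falls through to the checkpoint branch, and without --checkpoint the script raises NotImplementedError. Here is a minimal sketch of that decision logic with the model loading left out, so it runs without torch. The function resolve_model_source and its return values are hypothetical; only the version names and error messages are taken from predict.py.

```python
import os

HUB_VERSIONS = ("v1", "v2", "v3")  # versions that predict.py loads through torch.hub


def resolve_model_source(version, checkpoint_path=None):
    """Decide where the model would come from, mirroring predict.py's if/elif chain."""
    if version in HUB_VERSIONS:
        # predict.py calls torch.hub.load('saahiluppal/catr', version, pretrained=True) here
        return ("hub", version)
    # Any other version string is only usable together with a local checkpoint
    if checkpoint_path is None:
        raise NotImplementedError("No model to chose from!")
    if not os.path.exists(checkpoint_path):
        raise NotImplementedError("Give valid checkpoint path")
    return ("checkpoint", checkpoint_path)
```

Seen this way, "v4" is not rejected as an unknown version; it is simply treated as a request to load a checkpoint that was never supplied.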
Next, I give the first image to the v3 model.
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test001.png --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A man is riding a horse in a field
electron@diynoMacBook-Pro catr %
##[Result] A man is riding a horse in a field
A different sentence from the one the v2 model gave earlier was produced.
Interestingly, this time no period is output at the end.
A man is riding a horse in a field
###Trying a third image
electron@diynoMacBook-Pro catr % ls ./image_files
test001.png test002.jpg test003.jpg
electron@diynoMacBook-Pro catr %
I choose the v3 model.
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test003.jpg --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A group of young people posing for a picture.
electron@diynoMacBook-Pro catr %
##[Result] A group of young people posing for a picture.
The following sentence was returned.
A group of young people posing for a picture.
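The runs above were typed one at a time. Below is a minimal sketch for collecting captions for several images and model versions from a script, assuming it is run from the catr directory and that the caption is the last line predict.py prints to standard output (as in the logs above). The helper names are hypothetical.

```python
import subprocess
import sys


def build_command(image_path, version):
    """Command line equivalent to: python3 predict.py --path <image> --v <version>."""
    return [sys.executable, "predict.py", "--path", image_path, "--v", version]


def last_nonempty_line(text):
    """Return the last non-blank line of the output, which is where the caption appears."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    return lines[-1] if lines else ""


def caption_for(image_path, version):
    """Run predict.py once and return the caption it printed."""
    result = subprocess.run(
        build_command(image_path, version),
        capture_output=True, text=True, check=True,
    )
    return last_nonempty_line(result.stdout)
```

For example, caption_for("./image_files/test001.png", "v3") would start one inference run. Each call launches a new Python process and reloads the model, so this is convenient rather than fast.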
#Environment setup
electron@diynoMacBook-Pro Desktop % git clone https://github.com/saahiluppal/catr.git
Contents of the __catr directory__
electron@diynoMacBook-Pro Desktop % ls catr
LICENSE catr_demo.ipynb datasets finetune.py main.py predict.py
README.md configuration.py engine.py hubconf.py models requirements.txt
electron@diynoMacBook-Pro Desktop %
Modules listed in __requirements.txt__
electron@diynoMacBook-Pro Desktop % cd catr
electron@diynoMacBook-Pro catr % cat requirements.txt
torch
torchvision
numpy
transformers
tqdm
electron@diynoMacBook-Pro catr %
Install them
electron@diynoMacBook-Pro catr % pip3 install -r requirements.txt
( omitted )
Successfully installed click-8.0.1 huggingface-hub-0.0.12 joblib-1.0.1 packaging-21.0 regex-2021.8.3 sacremoses-0.0.45 tokenizers-0.10.3 transformers-4.9.1
WARNING: You are using pip version 21.2.1; however, version 21.2.2 is available.
You should consider upgrading via the '/usr/local/opt/python@3.9/bin/python3.9 -m pip install --upgrade pip' command.
electron@diynoMacBook-Pro catr %
Because this is a Transformer-based image caption generation model,
the Hugging Face transformers package was installed; predict.py imports BertTokenizer from it.
electron@diynoMacBook-Pro catr % ls
LICENSE catr_demo.ipynb datasets finetune.py main.py predict.py
README.md configuration.py engine.py hubconf.py models requirements.txt
electron@diynoMacBook-Pro catr %
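This time I do not train the model on my own training data (a set of image and caption pairs).
I only run the inference step: giving a target image to an already pre-trained model and having it output a caption that describes the input image.
The inference step uses __predict.py__.
Let's check which Python modules it imports.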
electron@diynoMacBook-Pro catr % head -50 predict.py
import torch
from transformers import BertTokenizer
from PIL import Image
import argparse
from models import caption
from datasets import coco, utils
from configuration import Config
import os
parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--path', type=str, help='path to image', required=True)
parser.add_argument('--v', type=str, help='version', default='v3')
parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None)
args = parser.parse_args()
image_path = args.path
version = args.v
checkpoint_path = args.checkpoint
config = Config()
if version == 'v1':
    model = torch.hub.load('saahiluppal/catr', 'v1', pretrained=True)
elif version == 'v2':
    model = torch.hub.load('saahiluppal/catr', 'v2', pretrained=True)
elif version == 'v3':
    model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)
else:
    print("Checking for checkpoint.")
    if checkpoint_path is None:
        raise NotImplementedError('No model to chose from!')
    else:
        if not os.path.exists(checkpoint_path):
            raise NotImplementedError('Give valid checkpoint path')
        print("Found checkpoint! Loading!")
        model,_ = caption.build_model(config)
        print("Loading Checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
start_token = tokenizer.convert_tokens_to_ids(tokenizer._cls_token)
end_token = tokenizer.convert_tokens_to_ids(tokenizer._sep_token)
image = Image.open(image_path)
image = coco.val_transform(image)
image = image.unsqueeze(0)
electron@diynoMacBook-Pro catr %
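The following __models, datasets, and configuration__ appear to be resources that come with the git clone.
If you move __predict.py__ to another directory, or port its code into another script file, these resources will probably also need to be placed in the same location as that script file.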
from models import caption
from datasets import coco, utils
from configuration import Config
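If the inference code does end up outside the clone, one alternative to copying these directories is to put the clone on the module search path. This is a minimal sketch, assuming the clone lives at ~/Desktop/catr as in this article; add_repo_to_path is a hypothetical helper, not something catr provides.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_dir):
    """Put the catr clone at the front of sys.path so its local packages are importable."""
    repo = str(Path(repo_dir).expanduser().resolve())
    if repo not in sys.path:
        sys.path.insert(0, repo)
    return repo


catr_root = add_repo_to_path("~/Desktop/catr")  # assumed location of the clone
# After this, the imports used by predict.py resolve against the clone:
# from models import caption
# from datasets import coco, utils
# from configuration import Config
```

Inserting at the front is a deliberate choice: the clone's datasets directory shares its name with the Hugging Face datasets library, so if that library is installed the local package needs to be found first.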
Next, let's look at where the target image file is read.
electron@diynoMacBook-Pro catr % cat predict.py | grep image
parser.add_argument('--path', type=str, help='path to image', required=True)
image_path = args.path
image = Image.open(image_path)
image = coco.val_transform(image)
image = image.unsqueeze(0)
predictions = model(image, caption, cap_mask)
electron@diynoMacBook-Pro catr %
This shows that argparse is used to receive the image path as a command-line argument.
electron@diynoMacBook-Pro catr % cat predict.py | grep image_path
image_path = args.path
image = Image.open(image_path)
electron@diynoMacBook-Pro catr %
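The same parser can be exercised without running inference. This is a minimal sketch that reproduces the three arguments from predict.py and parses an explicit argument list instead of sys.argv; the build_parser wrapper is hypothetical.

```python
import argparse


def build_parser():
    """Recreate the command-line interface defined at the top of predict.py."""
    parser = argparse.ArgumentParser(description='Image Captioning')
    parser.add_argument('--path', type=str, help='path to image', required=True)
    parser.add_argument('--v', type=str, help='version', default='v3')
    parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None)
    return parser


args = build_parser().parse_args(['--path', './image_files/test001.png', '--v', 'v2'])
print(args.path, args.v, args.checkpoint)  # ./image_files/test001.png v2 None
```

Because --v defaults to v3, leaving it out selects the v3 model, while --path is required and omitting it makes argparse exit with a usage error.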
electron@diynoMacBook-Pro catr % mkdir image_files
electron@diynoMacBook-Pro catr % cd image_files
electron@diynoMacBook-Pro image_files % ls
test001.png
electron@diynoMacBook-Pro image_files %
electron@diynoMacBook-Pro image_files % cd ..
electron@diynoMacBook-Pro catr % cat predict.py | grep args
args = parser.parse_args()
image_path = args.path
version = args.v
checkpoint_path = args.checkpoint
electron@diynoMacBook-Pro catr % cat predict.py | grep argument
parser.add_argument('--path', type=str, help='path to image', required=True)
parser.add_argument('--v', type=str, help='version', default='v3')
parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None)
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path /image_files/test001.png --v v2
Downloading: "https://github.com/saahiluppal/catr/archive/master.zip" to /Users/electron/.cache/torch/hub/master.zip
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/electron/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170M/170M [00:13<00:00, 13.6MB/s]
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.2/weight389123791.pth" to /Users/electron/.cache/torch/hub/checkpoints/weight389123791.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:26<00:00, 13.0MB/s]
Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 232k/232k [00:00<00:00, 792kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.0/28.0 [00:00<00:00, 7.09kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 466k/466k [00:00<00:00, 1.06MB/s]
Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 570/570 [00:00<00:00, 144kB/s]
Traceback (most recent call last):
File "/Users/electron/Desktop/catr/predict.py", line 46, in <module>
image = Image.open(image_path)
File "/usr/local/lib/python3.9/site-packages/PIL/Image.py", line 2968, in open
fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '/image_files/test001.png'
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test001.png --v v2
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A man riding on the back of a brown horse.
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % ls ./image_files/
test001.png
electron@diynoMacBook-Pro catr % ls /image_files/
ls: /image_files/: No such file or directory
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % ls /image_files/test001.png
ls: /image_files/test001.png: No such file or directory
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % ls ./image_files/test001.png
./image_files/test001.png
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % ls
LICENSE catr_demo.ipynb engine.py image_files predict.py
README.md configuration.py finetune.py main.py requirements.txt
__pycache__ datasets hubconf.py models
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % ls ./image_files
test001.png test002.jpg
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path /image_files/test002.jpg --v v2
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Traceback (most recent call last):
File "/Users/electron/Desktop/catr/predict.py", line 46, in <module>
image = Image.open(image_path)
File "/usr/local/lib/python3.9/site-packages/PIL/Image.py", line 2968, in open
fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '/image_files/test002.jpg'
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v2
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A group of men laying on top of a bed.
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth" to /Users/electron/.cache/torch/hub/checkpoints/weight493084032.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:24<00:00, 13.9MB/s]
A group of men laying on top of a beach.
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v1
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.1/weights_9348032.pth" to /Users/electron/.cache/torch/hub/checkpoints/weights_9348032.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:23<00:00, 14.4MB/s]
A man in a wetsuit laying on a beach.
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v4
Checking for checkpoint.
Traceback (most recent call last):
File "/Users/electron/Desktop/catr/predict.py", line 32, in <module>
raise NotImplementedError('No model to chose from!')
NotImplementedError: No model to chose from!
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test001.png --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A man is riding a horse in a field
electron@diynoMacBook-Pro catr % ls ./image_files
test001.png test002.jpg test003.jpg
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test003.jpg --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A group of young people posing for a picture.
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % cat predict.py | import
zsh: command not found: import
electron@diynoMacBook-Pro catr % cat predict.py | grep import
import torch
from transformers import BertTokenizer
from PIL import Image
import argparse
from models import caption
from datasets import coco, utils
from configuration import Config
import os
electron@diynoMacBook-Pro catr % ls
LICENSE catr_demo.ipynb engine.py image_files predict.py
README.md configuration.py finetune.py main.py requirements.txt
__pycache__ datasets hubconf.py models
electron@diynoMacBook-Pro catr %