Trying out CATR, a Transformer-based image captioning model

Posted at 2021-08-05

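The code I ran

There is a PyTorch implementation of CAption TRansformer (CATR), so I cloned it with git and ran it.

Running it

The repository provides pre-trained models; as the logs further down show, the weights are downloaded through torch.hub the first time a version is used rather than shipped inside the clone.

This time I feed an image file to a pre-trained model and have it output a descriptive sentence (a caption) for that image.

Inference with a pre-trained model is done with predict.py.
The image to caption is specified by passing its path as a command-line argument.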

Terminal
electron@diynoMacBook-Pro catr % cat predict.py | grep image_path
image_path = args.path
image = Image.open(image_path)
electron@diynoMacBook-Pro catr %

Getting an image file from Google Image Search

test001.png

Terminal
electron@diynoMacBook-Pro catr % mkdir image_files
electron@diynoMacBook-Pro catr % cd image_files 
electron@diynoMacBook-Pro image_files % ls
test001.png
electron@diynoMacBook-Pro image_files % 

Passing the downloaded image file and generating a caption

  • Pre-trained models v1, v2, and v3 are available; I start with v2.

Running the command starts the download of the v2 model's pre-trained weight file.

Terminal
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test001.png --v v2
Downloading: "https://github.com/saahiluppal/catr/archive/master.zip" to /Users/electron/.cache/torch/hub/master.zip
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/electron/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170M/170M [00:13<00:00, 13.6MB/s]
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.2/weight389123791.pth" to /Users/electron/.cache/torch/hub/checkpoints/weight389123791.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:26<00:00, 13.0MB/s]
Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 232k/232k [00:00<00:00, 792kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.0/28.0 [00:00<00:00, 7.09kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 466k/466k [00:00<00:00, 1.06MB/s]
Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 570/570 [00:00<00:00, 144kB/s]
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A man riding on the back of a brown horse.
electron@diynoMacBook-Pro catr % 

[Result] A man riding on the back of a brown horse.

The sentence "A man riding on the back of a brown horse." was printed to standard output in the Terminal.

Note that the pre-trained weight file for v2 is downloaded via torch.hub the first time v2 is selected.

predict.py
if version == 'v1':
   model = torch.hub.load('saahiluppal/catr', 'v1', pretrained=True)
elif version == 'v2':
   model = torch.hub.load('saahiluppal/catr', 'v2', pretrained=True)
elif version == 'v3':
   model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)
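
The three branches differ only in the entry-point name, so the selection can be written more compactly. The following is a minimal sketch, not the repository's code: `HUB_REPO`, `HUB_VERSIONS`, and `load_pretrained` are hypothetical names, and it assumes the hub entry points are exactly the `v1`, `v2`, and `v3` used in predict.py.

```python
HUB_REPO = "saahiluppal/catr"       # repository name passed to torch.hub.load in predict.py
HUB_VERSIONS = ("v1", "v2", "v3")   # entry points used in predict.py


def load_pretrained(version):
    """Load a pre-trained CATR model for a known version via torch.hub."""
    if version not in HUB_VERSIONS:
        raise ValueError(f"unknown version {version!r}; expected one of {HUB_VERSIONS}")
    # Imported lazily so the validation above works even without torch installed.
    import torch
    # Downloads the weights on first use, then reuses the local cache.
    return torch.hub.load(HUB_REPO, version, pretrained=True)
```

Validating the version before calling torch.hub means a typo fails immediately instead of after a network request.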

Let's try another image.

Terminal
electron@diynoMacBook-Pro catr % ls ./image_files 
test001.png test002.jpg
electron@diynoMacBook-Pro catr % 

test002.jpg

Terminal
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v2
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A group of men laying on top of a bed.
electron@diynoMacBook-Pro catr % 

[Result] A group of men laying on top of a bed.

This time the sentence "A group of men laying on top of a bed." was returned.

Trying the pre-trained v3 model

Terminal
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth" to /Users/electron/.cache/torch/hub/checkpoints/weight493084032.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:24<00:00, 13.9MB/s]
A group of men laying on top of a beach.
electron@diynoMacBook-Pro catr % 

Almost the same sentence was returned: v3 says "beach" where v2 said "bed".
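
Differences this small are easy to miss by eye. As a rough sketch (the helper name `word_diff` is hypothetical, and comparing whitespace-separated words is just one possible choice), the standard library's difflib can show where two captions diverge:

```python
import difflib


def word_diff(a, b):
    """Return (old_words, new_words) pairs where two captions differ."""
    a_words, b_words = a.split(), b.split()
    matcher = difflib.SequenceMatcher(a=a_words, b=b_words, autojunk=False)
    changes = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != "equal":
            changes.append((a_words[i1:i2], b_words[j1:j2]))
    return changes


v2_caption = "A group of men laying on top of a bed."
v3_caption = "A group of men laying on top of a beach."
print(word_diff(v2_caption, v3_caption))  # [(['bed.'], ['beach.'])]
```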

How about v1?

Terminal
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v1
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.1/weights_9348032.pth" to /Users/electron/.cache/torch/hub/checkpoints/weights_9348032.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:23<00:00, 14.4MB/s]
A man in a wetsuit laying on a beach.
electron@diynoMacBook-Pro catr % 

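[Result] A man in a wetsuit laying on a beach.

This time a different sentence was returned.
The returned sentence is "A man in a wetsuit laying on a beach."

There does not seem to be a v4: without --checkpoint, predict.py raises NotImplementedError.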

Terminal
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v4
Checking for checkpoint.
Traceback (most recent call last):
  File "/Users/electron/Desktop/catr/predict.py", line 32, in <module>
    raise NotImplementedError('No model to chose from!')
NotImplementedError: No model to chose from!
electron@diynoMacBook-Pro catr % 
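
The traceback comes from the else branch of predict.py: any version other than v1, v2, or v3 is only accepted together with a --checkpoint path. The sketch below restates that decision as a small function. It is an illustration under assumptions, not the repository's code: `choose_model_source` is a hypothetical name, and raising ValueError and FileNotFoundError instead of NotImplementedError is my own choice.

```python
import os

HUB_VERSIONS = ("v1", "v2", "v3")  # versions that predict.py loads through torch.hub


def choose_model_source(version, checkpoint_path=None):
    """Return ("hub", version) or ("checkpoint", path), mirroring predict.py's branching."""
    if version in HUB_VERSIONS:
        return ("hub", version)
    # Any other version string is only usable together with a local checkpoint.
    if checkpoint_path is None:
        raise ValueError(f"no hub model named {version!r} and no --checkpoint given")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"checkpoint not found: {checkpoint_path}")
    return ("checkpoint", checkpoint_path)
```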

Next, I give the first image to the v3 model.

Terminal
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test001.png --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A man is riding a horse in a field

electron@diynoMacBook-Pro catr %

[Result] A man is riding a horse in a field

A different sentence from the earlier v2 result was produced.
Interestingly, this time no period is printed at the end.

A man is riding a horse in a field

Trying a third image

Terminal
electron@diynoMacBook-Pro catr % ls ./image_files                                         
test001.png test002.jpg test003.jpg
electron@diynoMacBook-Pro catr % 

test003.jpg

I choose the v3 model.

Terminal
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test003.jpg --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A group of young people posing for a picture.
electron@diynoMacBook-Pro catr % 

[Result] A group of young people posing for a picture.

The following sentence was returned.

A group of young people posing for a picture.

Environment setup

Terminal
electron@diynoMacBook-Pro Desktop % git clone https://github.com/saahiluppal/catr.git

Contents of the catr directory

Terminal
electron@diynoMacBook-Pro Desktop % ls catr
LICENSE         catr_demo.ipynb     datasets        finetune.py     main.py         predict.py
README.md       configuration.py    engine.py       hubconf.py      models          requirements.txt
electron@diynoMacBook-Pro Desktop %

Modules listed in requirements.txt

Terminal
electron@diynoMacBook-Pro Desktop % cd catr 
electron@diynoMacBook-Pro catr % cat requirements.txt
torch
torchvision
numpy
transformers
tqdm
electron@diynoMacBook-Pro catr %

Installing the modules

Terminal
electron@diynoMacBook-Pro catr % pip3 install -r requirements.txt

(omitted)

Successfully installed click-8.0.1 huggingface-hub-0.0.12 joblib-1.0.1 packaging-21.0 regex-2021.8.3 sacremoses-0.0.45 tokenizers-0.10.3 transformers-4.9.1
WARNING: You are using pip version 21.2.1; however, version 21.2.2 is available.
You should consider upgrading via the '/usr/local/opt/python@3.9/bin/python3.9 -m pip install --upgrade pip' command.
electron@diynoMacBook-Pro catr % 

Since this is a Transformer-based image captioning model, the Hugging Face transformers package was installed (predict.py imports BertTokenizer from it).

Terminal
electron@diynoMacBook-Pro catr % ls
LICENSE         catr_demo.ipynb     datasets        finetune.py     main.py         predict.py
README.md       configuration.py    engine.py       hubconf.py      models          requirements.txt
electron@diynoMacBook-Pro catr %

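This time I do not train the model on my own training data (a set of image and caption pairs).

I only run the inference step: giving a target image to an already pre-trained model and having it output a caption describing the input image.

The inference step uses predict.py.
Let's check which Python modules it imports.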

Terminal
electron@diynoMacBook-Pro catr % head -50 predict.py
import torch

from transformers import BertTokenizer
from PIL import Image
import argparse

from models import caption
from datasets import coco, utils
from configuration import Config
import os

parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--path', type=str, help='path to image', required=True)
parser.add_argument('--v', type=str, help='version', default='v3')
parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None)
args = parser.parse_args()
image_path = args.path
version = args.v
checkpoint_path = args.checkpoint

config = Config()

if version == 'v1':
    model = torch.hub.load('saahiluppal/catr', 'v1', pretrained=True)
elif version == 'v2':
    model = torch.hub.load('saahiluppal/catr', 'v2', pretrained=True)
elif version == 'v3':
    model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)
else:
    print("Checking for checkpoint.")
    if checkpoint_path is None:
      raise NotImplementedError('No model to chose from!')
    else:
      if not os.path.exists(checkpoint_path):
        raise NotImplementedError('Give valid checkpoint path')
      print("Found checkpoint! Loading!")
      model,_ = caption.build_model(config)
      print("Loading Checkpoint...")
      checkpoint = torch.load(checkpoint_path, map_location='cpu')
      model.load_state_dict(checkpoint['model'])
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

start_token = tokenizer.convert_tokens_to_ids(tokenizer._cls_token)
end_token = tokenizer.convert_tokens_to_ids(tokenizer._sep_token)

image = Image.open(image_path)
image = coco.val_transform(image)
image = image.unsqueeze(0)

electron@diynoMacBook-Pro catr %
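
The same check can also be done from Python rather than by reading the listing. This is only a sketch using the standard library's ast module; `list_imports` is a hypothetical helper, and the sample source is a shortened copy of the import lines shown above.

```python
import ast


def list_imports(source):
    """Return the sorted top-level package names imported anywhere in the source."""
    modules = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            modules.extend(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module:
            modules.append(node.module.split(".")[0])
    return sorted(set(modules))


sample = """
import torch
from transformers import BertTokenizer
from PIL import Image
from models import caption
from datasets import coco, utils
"""
print(list_imports(sample))  # ['PIL', 'datasets', 'models', 'torch', 'transformers']
```

To inspect the real file, the source string could be read with `open("predict.py").read()` from inside the catr directory.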

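The models, datasets, and configuration modules below appear to be files bundled in the cloned repository.
If you move predict.py to another directory or port its code into a different script, these files probably need to be placed in the same location as that script.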

from models import caption
from datasets import coco, utils
from configuration import Config
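
If the code is ported to a script outside the clone, one alternative to copying these packages is to add the clone directory to sys.path. This is a minimal sketch under assumptions: `add_repo_to_path` is a hypothetical helper, and the path in the usage comment is only an example location.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_dir):
    """Put the cloned repository at the front of sys.path so its packages can be imported."""
    repo = Path(repo_dir).expanduser().resolve()
    if not repo.is_dir():
        raise NotADirectoryError(f"not a directory: {repo}")
    if str(repo) not in sys.path:
        sys.path.insert(0, str(repo))
    return repo


# Hypothetical usage, assuming the clone lives at ~/Desktop/catr:
# add_repo_to_path("~/Desktop/catr")
# from models import caption
# from datasets import coco, utils
# from configuration import Config
```

Inserting at the front of sys.path is deliberate: a separately installed package that is also named datasets (the Hugging Face one, for example) could otherwise be imported instead of the repository's own directory.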

Next, let's look at where the target image file is read.

Terminal
electron@diynoMacBook-Pro catr % cat predict.py | grep image
parser.add_argument('--path', type=str, help='path to image', required=True)
image_path = args.path
image = Image.open(image_path)
image = coco.val_transform(image)
image = image.unsqueeze(0)
        predictions = model(image, caption, cap_mask)
electron@diynoMacBook-Pro catr % 

This shows that the image path is received as a command-line argument via argparse.
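
To see what the parser produces, the three options can be recreated and fed an explicit argument list. The option definitions are copied from predict.py; wrapping them in a `build_parser` function is my own arrangement for illustration.

```python
import argparse


def build_parser():
    """Recreate the three command-line options that predict.py defines."""
    parser = argparse.ArgumentParser(description="Image Captioning")
    parser.add_argument("--path", type=str, help="path to image", required=True)
    parser.add_argument("--v", type=str, help="version", default="v3")
    parser.add_argument("--checkpoint", type=str, help="checkpoint path", default=None)
    return parser


# Parse an explicit argument list instead of sys.argv to see the resulting values.
args = build_parser().parse_args(["--path", "./image_files/test001.png", "--v", "v2"])
print(args.path, args.v, args.checkpoint)  # ./image_files/test001.png v2 None
```

Because --v defaults to v3, omitting it selects the v3 model.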

Terminal
electron@diynoMacBook-Pro catr % cat predict.py | grep image_path
image_path = args.path
image = Image.open(image_path)
electron@diynoMacBook-Pro catr %
electron@diynoMacBook-Pro catr % mkdir image_files
electron@diynoMacBook-Pro catr % cd image_files 
electron@diynoMacBook-Pro image_files % ls
test001.png
electron@diynoMacBook-Pro image_files % 
electron@diynoMacBook-Pro image_files % cd ..
electron@diynoMacBook-Pro catr % cat predict.py | grep args      
args = parser.parse_args()
image_path = args.path
version = args.v
checkpoint_path = args.checkpoint
electron@diynoMacBook-Pro catr % cat predict.py | grep argument
parser.add_argument('--path', type=str, help='path to image', required=True)
parser.add_argument('--v', type=str, help='version', default='v3')
parser.add_argument('--checkpoint', type=str, help='checkpoint path', default=None)
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % python3 predict.py --path /image_files/test001.png --v v2
Downloading: "https://github.com/saahiluppal/catr/archive/master.zip" to /Users/electron/.cache/torch/hub/master.zip
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/electron/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170M/170M [00:13<00:00, 13.6MB/s]
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.2/weight389123791.pth" to /Users/electron/.cache/torch/hub/checkpoints/weight389123791.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:26<00:00, 13.0MB/s]
Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 232k/232k [00:00<00:00, 792kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.0/28.0 [00:00<00:00, 7.09kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 466k/466k [00:00<00:00, 1.06MB/s]
Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 570/570 [00:00<00:00, 144kB/s]
Traceback (most recent call last):
  File "/Users/electron/Desktop/catr/predict.py", line 46, in <module>
    image = Image.open(image_path)
  File "/usr/local/lib/python3.9/site-packages/PIL/Image.py", line 2968, in open
    fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '/image_files/test001.png'
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test001.png --v v2
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A man riding on the back of a brown horse.
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % ls ./image_files/  
test001.png
electron@diynoMacBook-Pro catr % ls /image_files/ 
ls: /image_files/: No such file or directory
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % ls /image_files/test001.png
ls: /image_files/test001.png: No such file or directory
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % ls ./image_files/test001.png
./image_files/test001.png
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % ls
LICENSE         catr_demo.ipynb     engine.py       image_files     predict.py
README.md       configuration.py    finetune.py     main.py         requirements.txt
__pycache__     datasets        hubconf.py      models
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % ls ./image_files 
test001.png test002.jpg
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % python3 predict.py --path /image_files/test002.jpg --v v2 
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Traceback (most recent call last):
  File "/Users/electron/Desktop/catr/predict.py", line 46, in <module>
    image = Image.open(image_path)
  File "/usr/local/lib/python3.9/site-packages/PIL/Image.py", line 2968, in open
    fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '/image_files/test002.jpg'
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v2
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A group of men laying on top of a bed.
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth" to /Users/electron/.cache/torch/hub/checkpoints/weight493084032.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:24<00:00, 13.9MB/s]
A group of men laying on top of a beach.
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v1
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
Downloading: "https://github.com/saahiluppal/catr/releases/download/0.1/weights_9348032.pth" to /Users/electron/.cache/torch/hub/checkpoints/weights_9348032.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322M/322M [00:23<00:00, 14.4MB/s]
A man in a wetsuit laying on a beach.
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test002.jpg --v v4
Checking for checkpoint.
Traceback (most recent call last):
  File "/Users/electron/Desktop/catr/predict.py", line 32, in <module>
    raise NotImplementedError('No model to chose from!')
NotImplementedError: No model to chose from!
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test001.png --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A man is riding a horse in a field
electron@diynoMacBook-Pro catr % ls ./image_files                                         
test001.png test002.jpg test003.jpg
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % python3 predict.py --path ./image_files/test003.jpg --v v3
Using cache found in /Users/electron/.cache/torch/hub/saahiluppal_catr_master
A group of young people posing for a picture.
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % 
electron@diynoMacBook-Pro catr % cat predict.py | import                                  
zsh: command not found: import
electron@diynoMacBook-Pro catr % cat predict.py | grep import
import torch
from transformers import BertTokenizer
from PIL import Image
import argparse
from models import caption
from datasets import coco, utils
from configuration import Config
import os
electron@diynoMacBook-Pro catr % ls
LICENSE         catr_demo.ipynb     engine.py       image_files     predict.py
README.md       configuration.py    finetune.py     main.py         requirements.txt
__pycache__     datasets        hubconf.py      models
electron@diynoMacBook-Pro catr % 
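
In the log above, --path /image_files/test001.png fails because the leading / makes the path absolute (relative to the filesystem root), whereas ./image_files/test001.png is relative to the current directory. The error also appears only after the model files have been downloaded. A possible way to fail earlier is to check the path before loading anything; `resolve_image_path` below is a hypothetical helper, not part of predict.py.

```python
from pathlib import Path


def resolve_image_path(raw_path):
    """Return the absolute path of an existing image file, or raise early."""
    path = Path(raw_path).expanduser()
    if not path.is_file():
        hint = ""
        # A leading "/" means "from the filesystem root", not "from the current directory".
        if path.is_absolute() and Path("." + str(path)).is_file():
            hint = f" (did you mean .{path}?)"
        raise FileNotFoundError(f"image not found: {path}{hint}")
    return path.resolve()
```

Calling this on args.path right after argument parsing would report the mistake before torch.hub is contacted.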