2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Show attend and tellへ入門する

Last updated at Posted at 2020-04-01

0. 概要

XAIで説明可能なAIが語られるようになってから説明文の生成+画像の何処に着目したかを生成してくれるHybridなShow attend and tellというものが生まれた。
今回はをこれを使ってみる。

1. Show, Attend and Tell

pytorch==1.2.0の環境で動作させる。

1.1. Installation

まずはGitからファイルを持ってくる。

$ git clone https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning
$ cd a-PyTorch-Tutorial-to-Image-Captioning

a-PyTorch-Tutorial-to-Image-Captioningを作業ディレクトリとする。
次に、学習で必要なファイルのダウンロードを行う。

$ mkdir coco2014
$ cd coco2014
$ aria2c -x10 -s10 -k1M http://images.cocodataset.org/zips/train2014.zip http://msvocds.blob.core.windows.net/coco2014/train2014.zip
$ aria2c -x10 http://images.cocodataset.org/zips/val2014.zip
$ aria2c -x10 http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip
$ unzip train2014.zip val2014.zip caption_datasets.zip

1.2. Convert dataset

ダウンロードしたファイルをPytorchで効率的に読み込める形式に変形する。
以下のファイルを開いてパスを設定しなおす。

create_input_files.py
from utils import create_input_files

if __name__ == '__main__':
    # Create input files (along with word map)
    create_input_files(dataset='coco',
                       karpathy_json_path='./coco2014/dataset_coco.json',
                       image_folder='./coco2014/',
                       captions_per_image=5,
                       min_word_freq=5,
                       output_folder='./coco2014/',
                       max_len=50)

次に実行。

$ python create_input_files.py

こんな感じのが表示されたら15分程待つ。
image.png

1.3. Training

これで学習環境が整った。まず以下のファイルを開いてパスを設定。

train.py
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

# Data parameters
data_folder = './coco2014/'  # folder with data files saved by create_input_files.py
data_name = 'coco_5_cap_per_img_5_min_word_freq'  # base name shared by data files

# Model parameters
emb_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
...

次に、データセットの形式に変更があるので、以下のように何カ所かソースコードを書き換える。

train.py
...
scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
# scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
...
loss = criterion(scores.data, targets.data)
# loss = criterion(scores, targets)
...

これで動作するはずである。

$ python train.py

image.png

後は30分くらい学習させておく。

1.4. Inference

pytorchの互換性の問題で以下を使わないと多分上手く動かない。

$ pip uninstall scipy
$ pip install scipy==1.2.0
$ pip uninstall pillow
$ pip install pillow==6.2.1

後は、以下で実行する。

$ python caption.py  --img='<IMG_PATH>' --model='./BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar' --word_map='/mnt/exthd1/coco2014/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json' --beam_size=5

1時間ほど学習させた結果を張っておく。

image.png

Planeでしっかり飛行機を見ているのが凄い!

この後、CUBデータセットで学習させてみても上手く動いた。
また強化学習も交えてみるとかなりの精度向上が見られた。

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?