LoginSignup
1
2

More than 1 year has passed since last update.

FacebookのDeiT(Data-efficient image Transformers)

Last updated at Posted at 2021-06-06

概要

Deitもtransformerベースの事前学習モデルだが、ViTやEfficientNetなどと比較すると、より少ないデータ・少ないパラメタで学習が行えるように「蒸留トークン」という考え方を導入している。詳細は、原著を参照のこと。

対象読者

  • 人工知能・機械学習・深層学習の概要は知っているという方
  • 理論より実装を重視する方
  • Pythonを触ったことがある方

目次

1.31本目-ライブラリのセットアップ
2.32本目-使用する画像データサンプルの取得
3.33本目-事前学習モデルのロード
4.34本目-前処理
5.35本目-推論(画像認識)
6.参考文献
7.著者
8.参考動画

1. 31本目 ライブラリのセットアップ

  • 質問: ColabでDeiTを利用するために必要なモジュールをインストールせよ。

  • 回答: huggingfaceのtransformersのみ入れれば、依存するものは入ってくる。ver.は結構大事なのでつけておくことを推奨する。

pip install transformers==4.6.1

2. 32本目 使用する画像データサンプルの取得

  • 質問: 猫画像のサンプルを取得しpillowのImageにロードせよ。

  • 回答: https://cocodataset.org/#explore を利用する。pillowとrequestsを用いることでインターネット上に公開されている画像は簡単にpillowのimageにロード可能。

from PIL import Image
import requests

url = 'http://farm8.staticflickr.com/7250/7520201840_3e01349e3f_z.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

image.png

3. 33本目 事前学習モデルのロード

  • 質問: DeiT(Data-efficient image Transformers)の事前学習モデル'facebook/deit-base-distilled-patch16-224'をDeiTFeatureExtractor, DeiTForImageClassificationを利用しロードせよ。

  • 回答: feature_extractorとmodelにロードする。feature_extractorは、画像認識の前処理を担当させ、modelの方で推論(画像認識)タスクを解かせる。

from transformers import DeiTFeatureExtractor, DeiTForImageClassification
feature_extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224')

4. 34本目 前処理

  • 質問: 32本目でロードした画像を前処理にかけ、DeiTの事前学習モデルへ渡せる形式に変換せよ。

  • 回答: feature_extractorを利用する。


inputs = feature_extractor(images=image, return_tensors="pt")

前処理された後の画像はmatplotlibで簡単に確認できる。

import matplotlib.pyplot as plt
plt.imshow(inputs['pixel_values'][0][0])

image.png

5. 35本目 画像認識

  • 質問: 34本目で前処理をしたinputsを入力し、DeiTForImageClassificationを利用して画像認識を行え。認識結果は索引を変換したテキスト(キャプション・ラベル)として出力せよ。

  • 回答: 推論結果はlogitsで返却されるため、argmaxをとって、最も大きい値の索引を返却することで、キャプションに変換できる


outputs = model(**inputs)
logits = outputs.logits

predicted_class_idx = logits.argmax(-1).item()

print(f'Predicted Class:{model.config.id2label[predicted_class_idx]}')
  • 結果

Predicted Class:car wheel

猫と認識してくれなかったのだが、横に置いてある自転車の車輪の方を車の車輪として
認識したようだ。個人的にはCLIPの方が実用的に思えた。

11. 参考文献

-原著論文

-Facebook AI Blog

12. 著者

ツイッターでPython/numpy/pandas/pytorch関連の有益なツイートを配信してます。

@keiji_dl

13. 参考動画

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2