MLP-Mixer
MLP-Mixer: An all-MLP Architecture for Vision
はじめに
近年、自然言語処理・画像認識・音声認識において、Transformerを使った事前学習モデルの成果が著しいが、一方で古典的なMLP(Multi Layer Perceptron)を改良した驚くほどシンプルなアーキテクチャでTransformerと同等の性能がでることがMLP-Mixer: An all-MLP Architecture for Visionで確認されており、非常に気になるところである。本記事では、MLP-MixerとVision Transformerの推論の精度をImagenetのデータセットで確認してみる。
対象読者
- 人工知能・機械学習・深層学習の概要は知っているという方
- 理論より実装を重視する方
- Pythonを触ったことがある方
- Pytorchを触ったことがある方
目次
VisionTransformer V.S. MLP-Mixer
#1. 論文サマリ
MLP-Mixer: An all-MLP Architecture for Visionが主張しているポイントを
三つに絞るとすると、以下の通りだ。
- CNNやAttentionは画像認識のベンチマークで高精度な記録を出すのに重要だが必須ではない。
- 2つの”まぜまぜ”層によってMLPだけでもSoTAに匹敵する性能を出せる。
- 画像の位置情報(画像パッチ)のまぜまぜ。
- 画像の空間的情報(画像パッチまたぎの情報)のまぜまぜ。
- 訓練と推論にかかるコストもSoTAに匹敵する性能をだせる。
詳細は論文に、概要はAI-Scholarさんの記事"【MLP-Mixer】MLPがCNN,Transformerを超える日"におまかせ。本記事では実際にGoogle Colaboratoryで使う方法を解説する。
結果比較
MLP-MixerとViTどっちがすごいの?って結果だけ知りたいビジネスチックな方のために
結論ファースト。cipher10のテストにおいて、事前学習モデルでそのまま
比較すると全く同精度。再学習モデルでの比較ではViTがちょっとだけ(1%くらい)まだ高精度。
パラメータサイズはMLP-Mixerのが3/4くらいで省エネなので、自分はMLP-Miexer推し。
この研究をきっかけにgMLPが"Pay attention to MLPs"という
Attention意識した論文ではattentionも組み合わせて更に精度が上がっている模様。Attention is Not All You Needというわけだ。
Model | 事前学習モデル正解率(Cipher10) | 再学習後モデル正解率(Cipher10) |
---|---|---|
Vision Transformer(ViT) | 0.10063 | 0.976 |
MLP-Mixer | 0.10063 | 0.9646382 |
#2. 画像認識で比較
まずは、Vision Transformer(ViT)で画像認識。
セットアップ
リポジトリをvision_transformerというディレクトリにクローンして最新化
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer
!cd vision_transformer && git pull
!pip install -qr vision_transformer/vit_jax/requirements.txt```
利用可能な事前学習モデルを表示
!gsutil ls -lh gs://vit_models/imagenet*
!gsutil ls -lh gs://vit_models/sam
!gsutil ls -lh gs://mixer_models/*
gs://vit_models/imagenet21k+imagenet2012/:
377.57 MiB 2020-11-30T16:17:02Z gs://vit_models/imagenet21k+imagenet2012/R50+ViT-B_16.npz
330.29 MiB 2020-10-29T17:05:52Z gs://vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz
331.4 MiB 2020-10-20T11:48:22Z gs://vit_models/imagenet21k+imagenet2012/ViT-B_16.npz
336.89 MiB 2020-10-20T11:47:36Z gs://vit_models/imagenet21k+imagenet2012/ViT-B_32.npz
334.78 MiB 2021-03-12T09:04:16Z gs://vit_models/imagenet21k+imagenet2012/ViT-B_8.npz
1.13 GiB 2020-10-29T17:08:31Z gs://vit_models/imagenet21k+imagenet2012/ViT-L_16-224.npz
1.14 GiB 2020-10-20T11:53:44Z gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz
1.14 GiB 2020-10-20T11:50:56Z gs://vit_models/imagenet21k+imagenet2012/ViT-L_32.npz
gs://vit_models/imagenet21k/:
450.23 MiB 2021-01-20T14:12:43Z gs://vit_models/imagenet21k/R26+ViT-B_32.npz
439.85 MiB 2020-11-30T10:10:15Z gs://vit_models/imagenet21k/R50+ViT-B_16.npz
1.31 GiB 2021-01-20T14:11:54Z gs://vit_models/imagenet21k/R50+ViT-L_32.npz
393.69 MiB 2020-10-22T21:38:39Z gs://vit_models/imagenet21k/ViT-B_16.npz
400.01 MiB 2020-11-02T08:30:56Z gs://vit_models/imagenet21k/ViT-B_32.npz
393.72 MiB 2021-03-10T13:28:28Z gs://vit_models/imagenet21k/ViT-B_8.npz
2.46 GiB 2020-11-03T10:46:11Z gs://vit_models/imagenet21k/ViT-H_14.npz
1.22 GiB 2020-11-09T14:39:51Z gs://vit_models/imagenet21k/ViT-L_16.npz
1.23 GiB 2020-11-02T08:35:10Z gs://vit_models/imagenet21k/ViT-L_32.npz
TOTAL: 17 objects, 14306096550 bytes (13.32 GiB)
330.3 MiB 2021-07-13T19:39:09Z gs://vit_models/sam/ViT-B_16.npz
336.61 MiB 2021-07-13T19:39:10Z gs://vit_models/sam/ViT-B_32.npz
1.13 GiB 2021-07-13T19:39:38Z gs://vit_models/sam/ViT-L_16.npz
1.14 GiB 2021-07-13T19:39:38Z gs://vit_models/sam/ViT-L_32.npz
TOTAL: 4 objects, 3143025464 bytes (2.93 GiB)
6 B 2021-06-28T13:07:12Z gs://mixer_models/sam_$folder$
gs://mixer_models/imagenet1k/:
228.47 MiB 2021-05-05T14:09:01Z gs://mixer_models/imagenet1k/Mixer-B_16.npz
794.29 MiB 2021-05-05T14:09:02Z gs://mixer_models/imagenet1k/Mixer-L_16.npz
gs://mixer_models/imagenet21k/:
289.61 MiB 2021-05-05T14:09:11Z gs://mixer_models/imagenet21k/Mixer-B_16.npz
875.78 MiB 2021-05-05T14:09:12Z gs://mixer_models/imagenet21k/Mixer-L_16.npz
gs://mixer_models/sam/:
228.47 MiB 2021-06-28T13:08:09Z gs://mixer_models/sam/Mixer-B_16.npz
230.04 MiB 2021-06-28T13:08:08Z gs://mixer_models/sam/Mixer-B_32.npz
TOTAL: 7 objects, 2775222110 bytes (2.58 GiB)
ViTとMixerをダウンロード
model_name = 'ViT-B_32' #@param ["ViT-B_32", "Mixer-B_16"]
if model_name.startswith('ViT'):
![ -e "$model_name".npz ] || gsutil cp gs://vit_models/imagenet21k/"$model_name".npz .
if model_name.startswith('Mixer'):
![ -e "$model_name".npz ] || gsutil cp gs://mixer_models/imagenet21k/"$model_name".npz .
Colab上でモデルをプルダウンで選択できて便利。まずは、ViTでいってみよう。
TPU利用できるか確認(→今のところ無料プランではできない)。GPUでいこう。
import os
if 'google.colab' in str(get_ipython()) and 'COLAB_TPU_ADDR' in os.environ:
import jax
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
print('Connected to TPU.')
else:
print('No TPU detected. Can be changed under "Runtime/Change runtime type".')
No TPU detected. Can be changed under "Runtime/Change runtime type".
ログ設定とデバイス確認
from absl import logging
import flax
import jax
from matplotlib import pyplot as plt
import numpy as np
import tqdm
logging.set_verbosity(logging.INFO)
# 利用可能なデバイスの数を表示
jax.local_devices()
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker:
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
[GpuDevice(id=0, process_index=0)]
GPUは使えるよ。
Colabからソースを確認しながら実行する小技:画面右スプリットエディタでいくつかのコードファイルを開く。
from google.colab import files
files.view('vision_transformer/vit_jax/configs/common.py')
files.view('vision_transformer/vit_jax/configs/models.py')
files.view('vision_transformer/vit_jax/checkpoint.py')
files.view('vision_transformer/vit_jax/input_pipeline.py')
files.view('vision_transformer/vit_jax/models.py')
files.view('vision_transformer/vit_jax/momentum_clip.py')
files.view('vision_transformer/vit_jax/train.py')
モジュールのインポート
import sys
if './vision_transformer' not in sys.path:
sys.path.append('./vision_transformer')
%load_ext autoreload
%autoreload 2
from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import utils
from vit_jax import models
from vit_jax import momentum_clip
from vit_jax import train
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config
Cipher10/100の画像扱うためのHelper
# Helper functions for images.
labelnames = dict(
# https://www.cs.toronto.edu/~kriz/cifar.html
cifar10=('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'),
# https://www.cs.toronto.edu/~kriz/cifar.html
cifar100=('apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm')
)
def make_label_getter(dataset):
"""Returns a function converting label indices to names."""
def getter(label):
if dataset in labelnames:
return labelnames[dataset][label]
return f'label={label}'
return getter
def show_img(img, ax=None, title=None):
"""Shows a single image."""
if ax is None:
ax = plt.gca()
ax.imshow(img[...])
ax.set_xticks([])
ax.set_yticks([])
if title:
ax.set_title(title)
def show_img_grid(imgs, titles):
"""Shows a grid of images."""
n = int(np.ceil(len(imgs)**.5))
_, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
for i, (img, title) in enumerate(zip(imgs, titles)):
img = (img + 1) / 2 # Denormalize
show_img(img, axs[i // n][i % n], title)
データセットのロード
dataset = 'cifar10'
batch_size = 512
config = common_config.with_dataset(common_config.get_config(), dataset)
num_classes = input_pipeline.get_dataset_info(dataset, 'train')['num_classes']
config.batch = batch_size
config.pp.crop = 224
INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: cifar10/3.0.2
INFO:absl:Load dataset info from /tmp/tmpmx7kqa7vtfds
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
データセットの設定については、右記のinput_pipeline.pyを参照
ds_train = input_pipeline.get_data_from_tfds(config=config, mode='train')
ds_test = input_pipeline.get_data_from_tfds(config=config, mode='test')
del config # データセットをインスタンス化するためにのみ必要です。
INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: cifar10/3.0.2
INFO:absl:Load dataset info from /tmp/tmp5qpln01btfds
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Generating dataset cifar10 (/root/tensorflow_datasets/cifar10/3.0.2)
Downloading and preparing dataset cifar10/3.0.2 (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to /root/tensorflow_datasets/cifar10/3.0.2...
Dl Completed...: 100%
1/1 [00:20<00:00, 18.14s/ url]
Dl Size...: 100%
162/162 [00:20<00:00, 6.62 MiB/s]
Extraction completed...: 100%
1/1 [00:20<00:00, 20.50s/ file]
INFO:absl:Downloading https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz into /root/tensorflow_datasets/downloads/cs.toronto.edu_kriz_cifar-10-binaryODHPtIjLh3oLcXirEISTO7dkzyKjRCuol6lV8Wc6C7s.tar.gz.tmp.60fc73cd82d24890985164371806add3...
INFO:absl:Generating split train
Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteEIKKBP/cifar10-train.tfrecord
100%
49999/50000 [00:00<00:00, 122648.19 examples/s]
INFO:absl:Done writing /root/tensorflow_datasets/cifar10/3.0.2.incompleteEIKKBP/cifar10-train.tfrecord. Shard lengths: [50000]
INFO:absl:Generating split test
Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteEIKKBP/cifar10-test.tfrecord
100%
9999/10000 [00:00<00:00, 46586.49 examples/s]
INFO:absl:Done writing /root/tensorflow_datasets/cifar10/3.0.2.incompleteEIKKBP/cifar10-test.tfrecord. Shard lengths: [10000]
INFO:absl:Skipping computing stats for mode ComputeStatsMode.SKIP.
INFO:absl:Constructing tf.data.Dataset for split train[:98%], from /root/tensorflow_datasets/cifar10/3.0.2
Dataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.
INFO:absl:Load dataset info from /root/tensorflow_datasets/cifar10/3.0.2
INFO:absl:Load dataset info from /root/tensorflow_datasets/cifar10/3.0.2
INFO:absl:Reusing dataset cifar10 (/root/tensorflow_datasets/cifar10/3.0.2)
INFO:absl:Constructing tf.data.Dataset for split test, from /root/tensorflow_datasets/cifar10/3.0.2
INFO:absl:Load dataset info from /root/tensorflow_datasets/cifar10/3.0.2
説明用にテスト画像を一括して取得
batch = next(iter(ds_test.as_numpy_iterator()))
# 形の意味: [num_local_devices, local_batch_size, h, w, c]
batch['image'].shape
(1, 512, 224, 224, 3)
画像サンプルを表示
images, labels = batch['image'][0][:9], batch['label'][0][:9]
titles = map(make_label_getter(dataset), labels.argmax(axis=1))
show_img_grid(images, titles)
訓練用画像も確認していく。
# 画像がどのように切り取られ、どのように拡大されるかを確認
# 右側のエディタでinput_pipeline.get_data()をチェックすると、
# 画像の前処理がどのように異なるかがわかる。
batch = next(iter(ds_train.as_numpy_iterator()))
images, labels = batch['image'][0][:9], batch['label'][0][:9]
titles = map(make_label_getter(dataset), labels.argmax(axis=1))
show_img_grid(images, titles)
事前学習モデルロード
model_config = models_config.MODEL_CONFIGS[model_name]
model_config
classifier: token
hidden_size: 768
name: ViT-B_32
patches:
size: !!python/tuple
- 32
- 32
representation_size: null
transformer:
attention_dropout_rate: 0.0
dropout_rate: 0.0
mlp_dim: 3072
num_heads: 12
num_layers: 12
ViTがロードされていて、Transformerが12層あることが確認できるよ。
続いて、パラメータ初期化
# モデルの定義を読み込み、ランダムなパラメータを初期化。
# また、モデルをXLAにコンパイルします(初回は数分かかる)。
if model_name.startswith('Mixer'):
model = models.MlpMixer(num_classes=num_classes, **model_config)
else:
model = models.VisionTransformer(num_classes=num_classes, **model_config)
variables = jax.jit(lambda: model.init(
jax.random.PRNGKey(0),
# Discard the "num_local_devices" dimension of the batch for initialization.
batch['image'][0, :1],
train=False,
), backend='cpu')()
チェックポイントのロード
# チェックポイントのロードと変換。
# 実際に事前学習したモデルの結果をロードしますが、同時に
# 最終層の変更や、位置埋め込みのサイズ変更など、パラメータを少し変更します。
# 位置埋め込みを変更します。
# 詳細については、コードと論文のメソッドを参照のこと。
params = checkpoint.load_pretrained(
pretrained_path=f'{model_name}.npz',
init_params=variables['params'],
model_config=model_config,
)
INFO:absl:Inspect extra keys:
{'pre_logits/bias', 'pre_logits/kernel'}
INFO:absl:load_pretrained: drop-head variant
ここまでで、すべてのデータはホストメモリにある。
# 配列をデバイスに複製する。
# これにより、pytree paramsに含まれるすべての配列がShardedDeviceArrayになる。
# 同じデータがすべてのローカルデバイスに複製される。
# 単一のGPUの場合は、単にデータをデバイスに移動する。
params_repl = flax.jax_utils.replicate(params)
print('params.cls:', type(params['head']['bias']).__name__,
params['head']['bias'].shape)
print('params_repl.cls:', type(params_repl['head']['bias']).__name__,
params_repl['head']['bias'].shape)
params.cls: DeviceArray (10,)
params_repl.cls: _ShardedDeviceArray (1, 10)
/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:391: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
"jax.host_count has been renamed to jax.process_count. This alias "
/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:378: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
"jax.host_id has been renamed to jax.process_index. This alias "
そして、モデルのフォワードパスの呼び出しを、利用可能なすべてのデバイスにマッピングする。
vit_apply_repl = jax.pmap(lambda params, inputs: model.apply(
dict(params=params), inputs, train=False))
正解率を計測する用に関数を定義する。
def get_accuracy(params_repl):
"""Returns accuracy evaluated on the test set."""
good = total = 0
steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
predicted = vit_apply_repl(params_repl, batch['image'])
is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
good += is_same.sum()
total += len(is_same.flatten())
return good / total
とりあえず、ViTを事前学習モデルのまま再学習なしにcipher10で正解率計測してみる。
# 再学習していない状態での精度
get_accuracy(params_repl)
INFO:absl:Load dataset info from /root/tensorflow_datasets/cifar10/3.0.2
100%|██████████| 19/19 [01:25<00:00, 4.49s/it]
DeviceArray(0.10063734, dtype=float32)
10%という結果。そりゃあそうか。再学習してないもの。
ということで再学習してみる。
total_steps = 100
warmup_steps = 5
decay_type = 'cosine'
grad_norm_clip = 1
#accum_steps = 8
accum_steps = 32
base_lr = 0.03
# 詳細は、右側のエディタでtrain.make_update_fnを確認のこと。
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
update_fn_repl = train.make_update_fn(
apply_fn=model.apply, accum_steps=accum_steps, lr_fn=lr_fn)
# モメンタムで半分の精度を使用しメモリを節約するグラデーションクリッピングを利用。
opt = momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)
# ドロップアウトのためPRNGsを初期化
update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))
/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:391: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
"jax.host_count has been renamed to jax.process_count. This alias "
/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:378: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
"jax.host_id has been renamed to jax.process_index. This alias "
TPUで20分程度かかるらしいがGPUでも30分くらいだった。
losses = []
lrs = []
for step, batch in zip(
tqdm.trange(1, total_steps + 1),
ds_train.as_numpy_iterator(),
):
opt_repl, loss_repl, update_rng_repl = update_fn_repl(
opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
losses.append(loss_repl[0])
lrs.append(lr_fn(step))
plt.plot(losses)
plt.figure()
plt.plot(lrs)
さて、再学習後の正解率は
# 精度計測
print(model_name)
print(get_accuracy(opt_repl.target))
INFO:absl:Load dataset info from /root/tensorflow_datasets/cifar10/3.0.2
ViT-B_32
100%|██████████| 19/19 [00:50<00:00, 2.66s/it]
0.9766653
97.6%と結構高い。さすがViT!!!
つづいてMLP-Mixerで画像認識。変えるのはここだけ。
モデルのサイズがまず3/4くらいになったのがうれしい。
まず、事前学習モデルのままの精度はこちら。
# 再学習していない状態での精度
get_accuracy(params_repl)
INFO:absl:Load dataset info from /root/tensorflow_datasets/cifar10/3.0.2
100%|██████████| 19/19 [02:37<00:00, 8.28s/it]
DeviceArray(0.10063734, dtype=float32)
再学習の損失をプロットするとこんな感じ
losses = []
lrs = []
# GPUで35分程度かかる
for step, batch in zip(
tqdm.trange(1, total_steps + 1),
ds_train.as_numpy_iterator(),
):
opt_repl, loss_repl, update_rng_repl = update_fn_repl(
opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
losses.append(loss_repl[0])
lrs.append(lr_fn(step))
plt.plot(losses)
plt.figure()
plt.plot(lrs)
再学習後の精度はこちら。
# 精度計測
print(model_name)
print(get_accuracy(opt_repl.target))
INFO:absl:Load dataset info from /root/tensorflow_datasets/cifar10/3.0.2
Mixer-B_16
100%|██████████| 19/19 [02:30<00:00, 7.91s/it]
0.9646382
長い記事を読んでいただき感謝です。
参考文献
-Vision Transformer / MLP-Mixer
著者
ツイッターもやってます。
@keiji_dl
その他オススメ
人工知能Webアプリ開発入門
画像認識をViT(Vision Transformer)を用いて行うWebアプリをPython/Pytorch Lightning Flask/Jinja2/Bootstrap/JQuery/CSS/HTMLなどをフルスタックに使って7ステップ2時間で解説する動画を開設しました。
Transformer 101本ノック(Kindle Unlimited)。
ソースコードへのリンクも手に入りますのでぜひご一読下さい。