概要
MAEのコラ画像生成の能力を測るため,軽い実験をしてみます.ちなみにMAEとはBERTのように入力画像のMASK部分を復元するようなモデルです.
用意
facebook research公式のこちらのスクリプトを使います.
ソース
- 元のソースコードでは入力画像をランダムマスクするため,これをマニュアルにしましょう.
- 今回は上に張り付ける画像の境界部分を手作業で指示します.
import types
def random_masking(self, x, mask_ratio):
N, L, D = x.shape # batch, length, dim
######## マスクするパッチを手作業で指示します.########
mask_patch = [19,20,21,22,23,24,33,47,61,75,89,103,104,107,108,94,80,66,52,38,93,79,37,90]
####################################################
len_keep = 14**2 - len(mask_patch)
noise = torch.zeros(N, L, device=x.device) # noise in [0, 1]
noise[:, mask_patch] = 1.0
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
model_mae_gan.random_masking = types.MethodType(random_masking, model_mae_gan)
- マスクパッチの指示は以下のように番号を参照します.



結果
上記のコードを"Run MAE on the image"の直前に実行してから,"Run MAE on the image"を実行してみます.
グレーでマスクされた部分がモデルにより復元されています.


なんとも言えない結果ですね.別の画像でもやってみましょう.


ちょっと位置調整をします.


画像はDiamond online様と自由民主党様よりお借りしました.
まとめ
本記事はfacebook research様が公開されているMAEの学習モデルのコラ画像生成の能力を見てみました.結果は...実用できるほどではありませんね...