はじめに
MNISTやCIFAR10といった画像分類タスクのベンチマークデータセットはtorchvisionで簡単に読み込むことができました。
meta-learningのベンチマークとして用いられるOmniglotやMiniImageNetなどのデータセットに対してもこのようなものはないのかと思い調べていたところtorchmetaというライブラリに出会ったのでご紹介することにしました。
meta-learningにおいては通常の学習に比べdataloaderの設計が少し面倒なのですが、それもtorchmetaではうまいことやってくれるので非常に便利です。
meta-learningについて
ここではN-way K-shotのfew-shot learningにmeta-learningを用いるケースを考え簡単にmeta-learningについてご紹介します。
N-way K-shotのfew-shot learningとは、まず大規模な学習データセットでモデルを学習して、学習後に学習データに含まれていなかった新たなN-classの画像について各classについてK枚ずつの画像のみから再学習しN-class分類器を学習する問題です。通常はK=1やK=5とする場合が多く、少ない画像で学習することになります。
ここからはこれに対してmeta-learningを応用する手順について話します。
データセットはtrain setとtest setに分割します。
学習時にはデータはtaskと呼ばれる塊で管理し、taskにはN*K枚のmeta-train dataとQ枚のmeta-test dataが含まれています。meta-trainとmeta-testは共にtrain setからサンプリングされます。
各taskではNK枚のmeta-trainで学習しQ枚のmeta-testで評価し、meta-test時の勾配を用いてmodelのパラメータを更新します(ここら辺は手法によって若干異なる場合があるのであくまでイメージだと思っていただければと思います)。
これを複数のtaskに対し繰り返すことで、NK枚の画像から学習する方法を学習するという感じになります。
例えばOmniglotという手書き文字データセットは50の言語の計1623文字のデータセットになっており、各文字がclassに対応しています。これは日本語の「あ」「い」「う」などがそれぞれ1つのclassになっているということです。各classには20枚ずつ画像が含まれています。
この画像で学習する際には1623文字から例えば1100文字をtrain setに振り分け、各taskにおいて1100文字の中からN文字をランダムに選び、N*K枚の画像で学習するという流れになります。
torchmeta
ここから本題のtorchmetaの話に移りたいと思います。
torchmetaはOmniglotやMiniImageNetといったデータセットの読み込みだけでなく、上記のようなtaskをサンプリングするdataloaderを提供しています。
インストールは簡単で
pip install torchmeta
とやればできます。
例えば5-way 5-shotでMiniImageNetを読み込みたい場合は
from torchmeta.datasets.helpers import miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader
dataset = miniimagenet("data", ways=5, shots=5, test_shots=15, meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=1, num_workers=4, shuffle=Flase)
とするとdataloaderができます。実際に学習する際にはミニバッチで学習するのでBatchMetaDataLoaderによってバッチを作成しています(ここではバッチサイズを1としています)。
また実際の学習時にはshuffle=Trueとすべきですが、ここでは後ほど5-way 5-shotでデータが読み込まれていることを確かめるためにFalseとしています。
データの取り出しは
dataiter = iter(dataloader)
data = dataiter.next()
train_images, train_labels = data['train']
とすることででき、ここで取り出されたtrain_imagesの形状は[1, 25, 3, 84, 84]となっています(第一次元がbatch数、第二次元がtaskに含まれる画像の枚数N*K、第三次元が画像のチャンネル数、第四次元、第五次元が画像の縦横のピクセル数)。
import numpy as np
import matplotlib.pyplot as plt
import torchvision
def imshow(img):
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1,2,0)))
plt.show()
imshow(torchvision.utils.make_grid(train_images[0]))
とすると先ほど取り出した画像が次のように表示されます。
左上から順番に5枚ずつ同じclassに含まれる画像が取り出されていることがわかるかと思います(一部よくわからない画像もありますが)。
終わりに
公式の紹介記事とGitHubを参考に動かしてみてとりあえず最低限学習に使えそうな状態にはこれたので紹介記事を書かせていただきました。
GitHubの中にドキュメントも用意されているみたいなので、そちらも併せて参照されると良いと思います。
またmeta-learningについてはかなり緩く書いてしまったので、何かあればご指摘いただけたらと思います。