11
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Model-Agnostic Meta-Learning (MAML)を実装してみた

Last updated at Posted at 2020-07-16

はじめに

Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks [Finn et al., ICML'17]の再現実装を行ました。
著者実装自体はここにあるのですが、自力で実装してみようというのが今回の主旨となります。

論文について

Model-Agnostic Meta-Learningというタイトルの通り、モデルの構造に関わらず適用でき、また通常の分類問題や回帰問題だけでなく強化学習などにおいても適用できるgeneralなmeta-learningの枠組みを提案したのがこの論文です。
meta-learningはfew-shot learningの文脈で登場する手法です。
CVPR2020で行われたTowards Annotation-Efficient Learning: Few-Shot, Self-Supervised, and Incremental Learning ApproachesというTutorialによるとfew-shot learningには大きく分けて以下の4つの手法があるのですが、

  • Metric learning
  • Meta-learning with memory modules
  • Optimization based meta-learning
  • Learn to predict model parameters

MAMLはOptimization based meta-learningに該当する手法であり、またこのカテゴリの代表例となっています。
MAMLはfew-shot learningに関する手法の性能を比較する際のbaselineとしてよく用いられるだけでなく、MAMLを拡張した手法が多く提案されるなど、few-shot learningやmeta-learningに関する論文を読んでいく上でとても重要な手法だと考えられます。

手法について

MAMLのアイデアは至ってシンプルであり、一言で言うと「パラメータの更新式を工夫する」というものです。
具体的には

\theta'_i = \theta - \alpha\nabla_\theta L_{T_i}(f_\theta) \\
\theta \leftarrow \theta - \beta\nabla_\theta\sum_{T_i\sim p(T)}L_{T_i}(f_{\theta'_i})

という2つの式によりモデルのパラメータ$\theta$を更新します。$\theta'_i$というのはタスク$T_i$についてfine-tuningしたときのパラメータであり、1つ目の式はfine-tuningの更新式を表しています。
タスクというのはエピソードとも呼ばれるmeta-learning特有のデータ集合のことです。それぞれのタスクにはtrain dataとquery dataが含まれており、画像分類における5クラス分類の5-shotという設定だと

  • train : fine-tuningのためのデータ。5クラス*5-shotの25枚の画像。1つ目の更新式で用いる。
  • query : fine-tuningの出来を評価するためのデータ。5クラスのそれぞれについて15枚ずつの画像を用意することが多い。2つ目の更新式で用いる。

という構成になっており、タスク$T_1$にはクラス1~5、タスク$T_2$にはクラス6~10といったように、それぞれのタスクには異なるクラスの組み合わせが含まれています。
これらのタスクに対するfine-tuningの出来を2つ目の更新式で評価することにより、どのようなタスクに対しても5-shotのデータでうまくfine-tuningできるようなモデルのパラメータ$\theta$を学習するというのがMAMLの中身になります。

更新式自体がMAMLの中身ということになるのでmodel-agnostic、つまりどのようなモデルに対しても適用できる手法ということになり、またmeta-learningのための新たなパラメータを必要としないということにもなり、これらがMAMLの利点となっています。

実装について

実装したもののソースコードはここに置いてあるので、要点をかいつまんで説明していきたいと思います。

まずmodelについて

maml.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )

def functional_conv(x, weights, biases, bn_weights, bn_biases):
    x = F.conv2d(x, weights, biases, padding=1)
    x = F.batch_norm(x, running_mean=None, running_var=None, weight=bn_weights, bias=bn_biases, training=True)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x

class MAML(nn.Module):
    def __init__(self):
        super(MAML, self).__init__()
        self.conv1 = conv_block(3, 32)
        self.conv2 = conv_block(32, 32)
        self.conv3 = conv_block(32, 32)
        self.conv4 = conv_block(32, 32)
        self.logits = nn.Linear(800, 5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.x.view(x.size(0), -1)
        return self.logits(x)

    def adaptation(self, x, weights):
        for block in [1, 2, 3, 4]:
            x = functional_conv(x, weights[f'conv{block}.0.weight'], weights[f'conv{block}.0.bias'],
                                weights.get(f'conv{block}.1.weight'), weights.get(f'conv{block}.1.bias'))
        
        x = x.view(x.size(0), -1)
        return F.linear(x, weights['logits.weight'], weights['logits.bias'])

通常のmodelの定義とは別にadaptationと言うメソッドを定義していますが、これはfine-tuningで(もとの重みを維持したまま)一時的に重みを更新してlossを計算するときに必要なものです。

重みの更新は次のようになっています。

train.py
weights = OrderedDict(model.named_parameters())

# k-shotでadaptation
input_x = x_train[idx].to(device)
input_y = y_train[idx].to(device)
for iter in range(train_step):
    logits = model.adaptation(input_x, weights)
    loss = loss_fn(logits, input_y)
    gradients = torch.autograd.grad(loss, weights.values(), create_graph=train)
    weights = OrderedDict((name, param - lr * grad) for ((name, param), grad) in zip(weights.items(), gradients))

# queryで評価
input_x = x_val[idx].to(device)
input_y = y_val[idx].to(device)
logits = model.adaptation(input_x, weights)
loss = loss_fn(logits, input_y)
if train:
    model.train()
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()

k-shotのデータでfine-tuningし、その重みを用いてqueryデータに対するlossを計算して、このlossで重みを更新しています。下から2行目のところで

loss.backward(retain_graph=True)

とすることで、1つ目の式のfine-tuningの際に構築した計算グラフを2つ目の式の勾配を求めるところでも用いることができるようにしています。

実験

論文では回帰、分類、強化学習のそれぞれについて実験が行われていましたが、ここではminiImageNet[Ravi et al., ICLR'17]というデータセットを用いた5-way 5-shotの分類問題について実験してみました。
ハイパーパラメータ等の設定についてはソースコードを参照してください。また、データのロードにはtorchmetaというライブラリを使用しています。torchmetaの使い方についてはこちらに簡単にまとめています。
論文では60000iterationの学習でaccuracyは$63.11\pm 0.92%$となっています。この$\pm$は$95%$信頼区間を表しているとのことですが、実験を何回繰り返して算出したものなのかは記されていなかったので、ここら辺の設定は論文の実装とは異なってしまっている可能性があります。

ここでは60000iteration学習し、その後テストデータセットから1000個のタスクを作りそれぞれに対するaccuracyを求め、その1000個のaccuracdyからaccuracyの$95%$信頼区間を算出するということを行いました。その結果accuracyは$61.03\pm 0.41%$となり論文の結果にはわずかに届かない性能となりました。
また75000iteraion学習した場合は$62.32\pm 0.40%$となり、論文の結果とほぼ同等の性能となりました。

終わりに

MAMLを実際に実装してみることで、この手法がいかにシンプルでかつ強力な手法であるかを実感することができました。
実際に動かしてみた所感としては、この手法は更新式に二階微分を含んでしまうので計算グラフのところがややこしいと感じました。バッチサイズは論文に合わせて2としていたのですが、これを大きくすると計算グラフが大量に作られメモリが足りなくなるという問題が発生してしまったので、ここの部分でもっと上手い実装があったのかなと思い、これが今後の(個人的な)課題だと感じました。

11
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?