0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

メタ学習( MAML )のある実装。

0
Last updated at Posted at 2024-10-24

勉強した動機

今まで、通常の機械学習で音声合成、音声認識、機械翻訳や画像キャプショニングで非自己回帰型のモデルについて学習を行い精度を測定しました。精度を上げるためには、学習データと学習パラメータを増やす方法が有効です。しかし、大企業や AI の有力スタートアップでもないかぎり、学習データ作成に投資することはできません。そこで、メタ学習の少ないデータで学習できるという特徴を確かめるべく、MAML を実装し、5-way 10-shot の画像分類について実際に学習を行い精度を測定しました。5-way とは、5クラス画像分類の5 で、10-shotとは、一回に学習させる画像数が10枚ということです。結果は、少ないデータ量でそこそこの精度が得られました。この頃実感しましたが、計算時間とメモリは必要です。

参考にさせていただいたページ

実装をするにあたり、勉強しました。インターネットページで勉強しました。特に、

のページ。この実装が置いてある

は、かなり参考になりました。この他に、

の実装を参考にさせていただきました。論文の解説は

を参考にさせていただきました。感謝いたします。

のページも参考にさせていただきました。感謝いたします。

問題の定式化と学習用データセット

学習させる問題は、5クラスの画像分類問題です。100クラスに分類された画像を、5クラス*20タスクと考えます。メタ学習は学習の仕方を学習すると言いますが、20種類の5クラス分類問題を学習することにより、学習の仕方を学習するわけです

メタ学習でも、MAMLというモデルなので、色々な学習モデルに対してプログラムを作ることが可能です。その代わり、model.forward に加えて、model.adaptation という、入力と学習パラメータを与えて出力を計算する関数を作る必要があります。感情分析の model.forward と model.adpatation ですが、実例を掲載させていただきます。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    '''
    位置エンコーディング (Positional encoding)
    dim_embedding: 埋込み次元
    max_len      : 入力の最大系列長
    temperature  : 温度定数
    '''
    def __init__(self, dim_embedding: int,
                 max_len: int=5000, temperature=10000):
        super().__init__()

        assert dim_embedding % 2 == 0

        dim_t = torch.arange(0, dim_embedding, 2)
        dim_t = dim_t / dim_embedding
        dim_t = temperature ** dim_t

        x_encoding = torch.arange(max_len).unsqueeze(1)
        x_encoding = x_encoding / dim_t

        # 位置情報を保持するテンソル

        pe = torch.zeros(max_len, dim_embedding)
        pe[:, ::2] = x_encoding.sin()
        pe[:, 1::2] = x_encoding.cos()

        # PEをメモリに保存
        self.register_buffer('pe', pe)

    '''
    位置エンコーディングの順伝播
    x: 位置エンコーディングを埋め込む対象のテンソル,
       [バッチサイズ, 系列長, 埋め込み次元]
    '''
    def forward(self, x: torch.Tensor):
        seq = x.shape[1]
        x = x + self.pe[:seq]

        return x    

class SelfAttention(nn.Module):
    '''
    自己アテンション
    dim_hidden: 入力特徴量の次元
    num_heads : マルチヘッドアテンションのヘッド数
    qkv_bias  : クエリなどを生成する全結合層のバイアスの有無
    '''
    def __init__(self, dim_hidden: int, num_heads: int,
                 qkv_bias: bool=False):
        super().__init__()

        # 特徴量を各ヘッドのために分割するので、
        # 特徴量次元をヘッド数で割り切れるか検証
        assert dim_hidden % num_heads == 0

        self.num_heads = num_heads

        # ヘッド毎の特徴量次元
        dim_head = dim_hidden // num_heads

        # ソフトマックスのスケール値
        self.scale = dim_head ** -0.5

        # ヘッド毎にクエリ、キーおよびバリューを生成するための全結合層
        self.proj_in = nn.Linear(
            dim_hidden, dim_hidden * 3, bias=qkv_bias)

        # 各ヘッドから得られた特徴量を一つにまとめる全結合層
        self.proj_out = nn.Linear(dim_hidden, dim_hidden)

    '''
    順伝播関数
    x: 入力特徴量, [バッチサイズ, 特徴量数, 特徴量次元]
    '''
    def forward(self, x: torch.Tensor, attention_ids:torch.Tensor):
        bs, ns = x.shape[:2]

        qkv = self.proj_in(x)

        # view関数により
        # [バッチサイズ, 特徴量数, QKV, ヘッド数, ヘッドの特徴量次元]
        # permute関数により
        # [QKV, バッチサイズ, ヘッド数, 特徴量数, ヘッドの特徴量次元]
        qkv = qkv.view(
            bs, ns, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

        # クエリ、キーおよびバリューに分解
        q, k, v = qkv.unbind(0)

        # クエリとキーの行列積とアテンションの計算(今回マスクは不使用)
        # attnは[バッチサイズ, ヘッド数, 特徴量数, 特徴量数]
        attn = q.matmul(k.transpose(-2, -1))
        attn = attn + torch.unsqueeze( torch.unsqueeze( attention_ids * -1e9 , dim = 1 ), dim = 1 )
        attn = (attn * self.scale).softmax(dim=-1)

        # アテンションとバリューの行列積によりバリューを収集
        # xは[バッチサイズ, ヘッド数, 特徴量数, ヘッドの特徴量次元]
        x = attn.matmul(v)

        # permute関数により
        # [バッチサイズ, 特徴量数, ヘッド数, ヘッドの特徴量次元]
        # flatten関数により全てのヘッドから得られる特徴量を連結して、
        # [バッチサイズ, 特徴量数, ヘッド数 * ヘッドの特徴量次元]
        x = x.permute(0, 2, 1, 3).flatten(2)
        x = self.proj_out(x)

        return x

class fSelfAttention(nn.Module):
    '''
    自己アテンション
    dim_hidden: 入力特徴量の次元
    num_heads : マルチヘッドアテンションのヘッド数
    qkv_bias  : クエリなどを生成する全結合層のバイアスの有無
    '''
    def __init__(self, dim_hidden: int, num_heads: int,
                 qkv_bias: bool=False):
        super().__init__()

        # 特徴量を各ヘッドのために分割するので、
        # 特徴量次元をヘッド数で割り切れるか検証
        assert dim_hidden % num_heads == 0

        self.num_heads = num_heads

        # ヘッド毎の特徴量次元
        dim_head = dim_hidden // num_heads

        # ソフトマックスのスケール値
        self.scale = dim_head ** -0.5

        # ヘッド毎にクエリ、キーおよびバリューを生成するための全結合層
        #self.proj_in = nn.Linear(
        #    dim_hidden, dim_hidden * 3, bias=qkv_bias)

        # 各ヘッドから得られた特徴量を一つにまとめる全結合層
        #self.proj_out = nn.Linear(dim_hidden, dim_hidden)

    '''
    順伝播関数
    x: 入力特徴量, [バッチサイズ, 特徴量数, 特徴量次元]
    '''
    def forward(self, x: torch.Tensor, attention_ids: torch.Tensor, i, weights ):
        bs, ns = x.shape[:2]

        #qkv = self.proj_in(x)
        qkv = F.linear(x, weights['layers.' + str(i) + '.attention.proj_in.weight'], bias = None)


        # view関数により
        # [バッチサイズ, 特徴量数, QKV, ヘッド数, ヘッドの特徴量次元]
        # permute関数により
        # [QKV, バッチサイズ, ヘッド数, 特徴量数, ヘッドの特徴量次元]
        qkv = qkv.view(
            bs, ns, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

        # クエリ、キーおよびバリューに分解
        q, k, v = qkv.unbind(0)

        # クエリとキーの行列積とアテンションの計算(今回マスクは不使用)
        # attnは[バッチサイズ, ヘッド数, 特徴量数, 特徴量数]
        attn = q.matmul(k.transpose(-2, -1))
        attn = attn + torch.unsqueeze( torch.unsqueeze( attention_ids * -1e9 , dim = 1 ), dim = 1 )
        attn = (attn * self.scale).softmax(dim=-1)

        # アテンションとバリューの行列積によりバリューを収集
        # xは[バッチサイズ, ヘッド数, 特徴量数, ヘッドの特徴量次元]
        x = attn.matmul(v)

        # permute関数により
        # [バッチサイズ, 特徴量数, ヘッド数, ヘッドの特徴量次元]
        # flatten関数により全てのヘッドから得られる特徴量を連結して、
        # [バッチサイズ, 特徴量数, ヘッド数 * ヘッドの特徴量次元]
        x = x.permute(0, 2, 1, 3).flatten(2)
        #x = self.proj_out(x)
        x = F.linear(x, weights['layers.' + str(i) + '.attention.proj_out.weight'], weights['layers.' + str(i) + '.attention.proj_out.bias'])

        return x


class FNN(nn.Module):
    '''
    Transformerエンコーダ内の順伝播型ニューラルネットワーク
    dim_hidden     : 入力特徴量の次元
    dim_feedforward: 中間特徴量の次元
    '''
    def __init__(self, dim_hidden: int, dim_feedforward: int):
        super().__init__()

        self.linear1 = nn.Linear(dim_hidden, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, dim_hidden)
        self.activation = nn.GELU()

    '''
    順伝播関数
    x: 入力特徴量, [バッチサイズ, 特徴量数, 特徴量次元]
    '''
    def forward(self, x: torch.Tensor):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        return x

class fFNN(nn.Module):
    '''
    Transformerエンコーダ内の順伝播型ニューラルネットワーク
    dim_hidden     : 入力特徴量の次元
    dim_feedforward: 中間特徴量の次元
    '''
    def __init__(self, dim_hidden: int, dim_feedforward: int):
        super().__init__()

        #self.linear1 = nn.Linear(dim_hidden, dim_feedforward)
        #self.linear2 = nn.Linear(dim_feedforward, dim_hidden)
        self.activation = nn.GELU()

    '''
    順伝播関数
    x: 入力特徴量, [バッチサイズ, 特徴量数, 特徴量次元]
    '''
    def forward(self, x: torch.Tensor, i, weights ):
        #x = self.linear1(x)
        x = F.linear(x, weights['layers.' + str(i) + '.fnn.linear1.weight'], weights['layers.' + str(i) + '.fnn.linear1.bias'])
        x = self.activation(x)
        #x = self.linear2(x)
        x = F.linear(x, weights['layers.' + str(i) + '.fnn.linear2.weight'], weights['layers.' + str(i) + '.fnn.linear2.bias'])

        return x


class TransformerEncoderLayer(nn.Module):
    '''
    Transformerエンコーダ層
    dim_hidden     : 入力特徴量の次元
    num_heads      : ヘッド数
    dim_feedforward: 中間特徴量の次元
    '''
    def __init__(self, dim_hidden: int, num_heads: int,
                 dim_feedforward: int):
        super().__init__()

        self.attention = SelfAttention(dim_hidden, num_heads)
        self.fnn = FNN(dim_hidden, dim_feedforward)

        self.norm1 = nn.LayerNorm(dim_hidden)
        self.norm2 = nn.LayerNorm(dim_hidden)
        
        self.dropout = nn.Dropout( 0.1 )

    '''
    順伝播関数
    x: 入力特徴量, [バッチサイズ, 特徴量数, 特徴量次元]
    '''
    def forward(self, x: torch.Tensor, attention_ids: torch.Tensor):
        x0 = x
        x = self.attention(x, attention_ids) 
        x = self.dropout( x )
        x = self.norm1( x0 + x )
        x1 = x
        x = self.fnn(x)
        x = self.dropout( x )
        x = self.norm2( x + x1 )

        return x

class fTransformerEncoderLayer(nn.Module):
    '''
    Transformerエンコーダ層
    dim_hidden     : 入力特徴量の次元
    num_heads      : ヘッド数
    dim_feedforward: 中間特徴量の次元
    '''
    def __init__(self, dim_hidden: int, num_heads: int,
                 dim_feedforward: int):
        super().__init__()

        self.fattention = fSelfAttention(dim_hidden, num_heads)
        self.ffnn = fFNN(dim_hidden, dim_feedforward)
        self.dim_hidden = dim_hidden

        self.dropout = nn.Dropout( 0.1 )


    '''
    順伝播関数
    x: 入力特徴量, [バッチサイズ, 特徴量数, 特徴量次元]
    '''
    def forward(self, x: torch.Tensor, attention_ids: torch.Tensor, i, weights ):
        x0 = x
        x = self.fattention(x, attention_ids, i, weights ) 
        x = self.dropout( x )
        x = F.layer_norm(x0 + x, (self.dim_hidden,), weight=weights['layers.' + str(i) + '.norm1.weight'], bias=weights['layers.' + str(i) + '.norm1.bias'], eps=1e-05)
        x1 = x
        x = self.ffnn(x, i, weights)
        x = self.dropout( x )
        x = F.layer_norm(x + x1, (self.dim_hidden,), weight=weights['layers.' + str(i) + '.norm2.weight'], bias=weights['layers.' + str(i) + '.norm2.bias'], eps=1e-05)


        return x

class MAML(nn.Module):
        
    def __init__(self):
        super(MAML, self).__init__()
        num_class = 2
        num_heads = 4
        num_layers = 6
        self.num_layers = num_layers
        dim_hidden = 256
        self.dim_hidden = dim_hidden
        dim_feedforward = 384
        max_seq =128
        self.max_seq = max_seq
        self.embed = nn.Embedding( 30522, dim_hidden )
        self.pe = PositionalEncoding( dim_hidden )
        self.layers = nn.ModuleList([TransformerEncoderLayer(
            dim_hidden, num_heads, dim_feedforward
        ) for _ in range(num_layers)])
        self.ftrenc = fTransformerEncoderLayer( dim_hidden, num_heads, dim_feedforward )

        self.logits = nn.Linear( dim_hidden, num_class )
   

    def forward(self, x, attention_ids):
        
        x = self.embed( x )
        x = self.pe( x )

        # Transformerエンコーダ層を適用
        for layer in self.layers:
            x = layer(x, attention_ids)
            #print( "layer x:", x )
      
        #x = x.view( x.size(0), -1 )
     
        x = x[:,0,:]
     
        return self.logits(x)

    def adaptation(self, x, attention_ids, weights):
        x = F.embedding(x, weights['embed.weight'] )
        x = self.pe( x )
        for block in range( self.num_layers ):
            x = self.ftrenc(x, attention_ids, block, weights )
        
        #x = x.view( x.size(0), -1 )
        
        x = x[:,0,:]
        
        return F.linear(x, weights['logits.weight'], weights['logits.bias'])

model.forward と model.adaptation を試すためのプログラム

import torch
import random
import numpy as np
from maml import MAML
from collections import OrderedDict

model = MAML()
    
weights = OrderedDict(model.named_parameters()) #今回の基準パラメータ
input = torch.randint( 0, 10000, size=(10, 128 ) )
input_mask = torch.randint( 0, 1 , size=(10,128 ) ).bool()

seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
model.train()
#model.eval()
output1 = model( input, input_mask )
print( "output of model.forward:", output1[0][0] )

seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
model.train()
#model.eval()
output2 = model.adaptation( input, input_mask, weights )
print( "output of model.adaptation:", output2[0][0] )
output of model.forward: tensor(-0.3522, grad_fn=<SelectBackward0>)
output of model.adaptation: tensor(-0.3522, grad_fn=<SelectBackward0>)

データは、CIFAR100 から自前で学習用データセットを作りました。画像データは、[ outer_batch, num_task, N-way * k-shot, 3, 32, 32 ] の形で、ラベルデータは、[ outer_batch, num_task, N-way * k-shot ] の形で作るようにしました。 [ 3, 32, 32 ] は画像の次元です。N-way は5で、k-shot は、サポートデータについては10、クエリーデータについては15です。

学習結果

学習は 100 エポックで、train acc = 0.98, val acc = 0.61, test acc = 0.63 でした。trainフェーズに比較して、validation フェーズと test フェーズの accuracy が悪いのはあとで説明します。学習にかかった時間は、RTX 6000 一枚で 1時間程度、CPU でも3時間程度でした。

プログラムの要点

MAML の学習プログラムでは、タスクが重要な役割を果たします。画像認識では、CIFAR100 のデータで、0~99 のクラスについて、0番目のタスク 0~4, 2番目のタスク 5~9,・・・,19番目のタスク 95~99 と20個のタスクを考えました。20個のタスクのうち、0〜16をtrain フェーズで用い、validation フェーズと test フェーズは17、18、19のタスクを用いました。その結果 train フェースと validation, test フェーズの acc に差が出たと考えられます。

あるタスクについて、データは support データと query データを作成します。加えて、勾配は、inner と outer と二種類計算します。

あるタスクについて、最初は、初期値の学習パラメータで support データについて loss を計算し inner の勾配を求めます。この時 torch.autograd.grad 関数を使い、create_graph = True として、outer フェーズに graph を渡す準備をします。この勾配から inner_lr を使って途中の学習パラメータを求めます。train_step = 2 回目の学習パラメータ―は、初期値(train_step=1)の学習パラメータ―と inner_lr と train_step = 1 で求めた勾配の積の差から求まります。勾配降下法。train_step=3 の学習パラメータは、train_step=2の学習パラメータとこのパラメータとサポートデータから得た loss を使って求めた勾配から勾配降下法を使って求めます。これを train_step = 5 回行い、勾配降下法を積み上げて学習パラメーターを求めます。次は outer フェーズです。

一つのタスクについて、inner の train_stepを train_step(=5)回行い学習パラメーターを求めて、求まった学習パラメーターを query データのモデルに適用して outer の loss を求めます。この outer の loss を用いて、全体の勾配を求めます。この時の勾配計算で torch.autograd.grad を用いて、勾配は、最初の初期値パラメータ( train_step = 1 の時のパラメーター)について求めます。この勾配計算をすべてのタスクについて行い、全体の勾配の和を求めます。

一般的には、outer の loss の和を求めてから、loss の和に対して勾配計算を行うようですが、メモリーを多く使うので、メモリー節約の観点から勾配の和を求めました。 loss に対して、勾配計算と和は可換だと思います。ですから、得られる計算結果である更新された学習パラメータは loss の和を求めてから勾配を計算しても、勾配の和を計算しても同じだと思います。

この全体の勾配の和を学習パラメータ param.grad に代入し、step 関数でモデルの学習パラメータ―を更新します。

MAML でも、model パラメータを求めるのですが、validation や test を行うとき、保存した model パラメータで計算した loss と acc ではなく、保存したモデルパラメータについて、興味のある task の inner フェーズで求めた学習パラメーターを用いて、model 計算を行うようです。

def adaptation(model, outer_optimizer, batch, loss_fn, train_step, train, lr1,  device):

    x_train = batch[0] #support 画像データ
    y_train = batch[1] #support ラベルデータ
    x_val = batch[2]   #query 画像データ
    y_val = batch[3]   #query ラベルデータ

    task_accs = []
    num_task = len( x_train )

    outer_loss_item = 0

    if train:
        weights0 = OrderedDict(model.named_parameters()) #今回の基準パラメータ
    for idx in range(x_train.size(0)): # task
        if train:
            weights = weights0
        weights2 = OrderedDict(model.named_parameters())
        # batch 抽出
        input_x = x_train[idx].to(device)
        input_y = y_train[idx].to(device)
        x = input_x
        y = input_y
        
        print('----Task',idx, '----')

        # タスクごとの損失の計算
        loss_item = 0
        for iter in range(train_step): # train_step のループ
            model.train()
            logits = model.adaptation(x, weights2 )
            loss = loss_fn(logits, y)
            loss_item += loss.item()
            #各タスクについて一番目の損失関数からモデルパラメーターを求める。
            #graph を残して、2回めの更新のときにその情報を使う。
            #gradients = torch.autograd.grad(loss, weights2.values())
            gradients = torch.autograd.grad(loss, weights2.values(), create_graph=True)
            weights2 = OrderedDict((name, param - lr1 * grad) for ((name, param), grad) in zip(weights2.items(), gradients))

        loss_item = loss_item / train_step 

        print("Inner Loss: ", loss.item())
        
        # query データからバッチ抽出
        input_x = x_val[idx].to(device)
        input_y = y_val[idx].to(device)
        
        # 訓練時に query データ( query_k * 5クラス ) で全体の勾配の総和を求める。
        x = input_x
        y = input_y
        # 各タスクについて、上で求めたモデルパラメーターを使って損失を求める。
        if train:
            model.train()
            logits = model.adaptation( x, weights2 )
        else:
            model.eval()
            with torch.no_grad():
                logits = model.adaptation( x, weights2 )
            
        outer_loss0 = loss_fn( logits, y )
        outer_loss_item += outer_loss0.item()
        if train:
            tmp = torch.autograd.grad( outer_loss0, weights.values() )
            if idx ==0:
                gradients2 = list(tmp)
            else:
                gradients2 = [x + y for x, y in zip(gradients2, list(tmp))]
        pre_label_id = torch.argmax( logits, dim = 1 )
        acc = torch.sum( torch.eq( pre_label_id, y ).float() ) / y.size(0)
        task_accs.append(acc)


    # 訓練時、モデルパラメーターを更新する。
    if train:
        outer_optimizer.zero_grad()
        for i, params in enumerate(model.parameters()):
            params.grad = gradients2[i]
        outer_optimizer.step()

    task_accs = torch.stack( task_accs )
    outer_loss_item = outer_loss_item / num_task

    print( "loss:", outer_loss_item )

    return outer_loss_item, torch.mean(task_accs).item()

adaptation 関数は、一見、loss と accuracy を返すだけのように見えるが、model.parametrs() の更新が一番大きな役割です。

実装したプログラム

わたくしもそうだったのですが、ページであれこれ説明されても良く分からなかったです。分かるためには、実装したプログラムを使ってみて、自分が正しいと思うように修正すると分かったと思えました。理解したい方は、github のプログラムをダウンロードして使ってみて、自分なりに直してみてください。よろしくお願いいたします。

感情分析の MAML メタ学習プログラムも置いておきます。タスクは、感情分析する文章の話題(domain) です。Transformer Encoder を使っています。<CLS>トークンは使わずに、テキストを tokeinze したあとの sequence の最大値を 128 として、Transformer Encoder の出力を [ batch, 128, 256] と固定しました。その結果、最終的な分類のための線形層を nn.Linear( 128*256, 2 )とすることができました。精度は 300エポックで train acc = 0.79, val acc = 0.74, test acc = 0.77 でした。学習にかかった時間は、300 epochs で RTX 6000 一枚で2時間程度でした。CPUでも4時間程度です。

このページを参考にしました。

感情分析について CLS トークンを用いたメタ学習も行ってみました。100エポックで、train acc = 0.65, val acc = 0.60, test acc = 0.58 でした。ソースを置いておきます。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?