0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

「Python で学ぶ画像認識」の Transformer を用いた画像キャプショニングのプログラムを非自己回帰型に改修しました。

Last updated at Posted at 2024-02-27

はじめに

これまで、音声合成、音声認識、機械翻訳において、推論を非自己回帰的に行う Transformer の使い方を提案してきました。今回は、画像キャプショニングのプログラムの推論を非自己回帰的に行うように改修しましたので、ご報告させていただきます。

提案する Transformerの 使い方。

Transformer デコーダーには、q, k, v の入力が必要です。デコーダーの Self Attention では、一般に、q,k,vの入力には、すべて、同じ q を入力します。また、 Cross Attention では、一般に、q には クエリー q を、k,v には、キーk を入力します。学習時には、通常の Transformer デコーダーの Self Attention への3つの入力には、教師データの time sequence で最後を削ったものを label に用います。また、Cross Attention への q 入力は、Self Attention の出力であり、k, v 入力は Transformer Encoder の出力です。この削った label を入力に用いることは、自己回帰型の推論をするという前提で使われています。推論を非自己回帰的に行うためには、この削った label を Transofrmer Decoder の入力に使えないことを意味します。それでは、どのような計算量が妥当なのでしょうか。提案する手法では、削った label の代わりに Transformer Encoder の出力を time sequence 方向にダウンサンプリングした計算量を用います。これは、学習時に損失を CTCLoss で計算すること、および、推論時に、モデルによる出力を CTC によりデコードして推論結果を得ることを前提に、モデル出力の time sequence への縛りをなくすことにより可能です。

ダウンサンプリングについて。

「Python で学ぶ画像認識」という本の p.339 に掲載されている CaptioningTransformer クラスを次のように改修しました。CaptioningTransformer は、TransformerDecoderLayer を層の厚さの回数だけ呼び出し、一般の TransformerDecoder の役割を果たしていると考えられます。

class CaptioningTransformer(nn.Module):
    '''
    CaptioningTransformerのコンストラクタ
    dim_embedding  : 埋め込み次元
    dim_feedforward: FNNの中間特徴次元
    num_heads      : マルチヘッドアテンションのヘッド数
    num_layers     : Transformerデコーダ層の数
    vocab_size     : 辞書の次元
    null_index     : NULLのID
    dropout        : ドロップアウト確率
    '''
    def __init__(self, dim_embedding: int, dim_feedforward: int,
                 num_heads: int, num_layers: int, vocab_size: int,
                 null_index: int, dropout: float=0.5, ds_rate: float=0.1):
        super().__init__()

        # 単語埋め込み
        #self.embed = nn.Embedding(
        #    vocab_size, dim_embedding, padding_idx=null_index)
        
        # 位置エンコーディング
        self.positional_encoding = PositionalEncoding(dim_embedding)

        # Transformerデコーダ
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(
                dim_embedding, num_heads, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])

        # 単語出力分布計算
        self.linear = nn.Linear(dim_embedding, vocab_size)

        # パラメータ初期化
        self._reset_parameters()
        
        self.ds_rate = ds_rate
    

    '''
    パラメータの初期化関数
    '''
    def _reset_parameters(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):

                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                nn.init.zeros_(module.bias)
                nn.init.ones_(module.weight)

    def downsample(self, enc_out, input_lengths ):

        max_label_length = int( round( enc_out.size(1) * self.ds_rate ) )

        polated_lengths = torch.round( torch.ones( enc_out.size(0) ) * enc_out.size(1) * self.ds_rate ).long()

        outputs_lens = torch.round( input_lengths * self.ds_rate ).long()
        #print( "output_lens", outputs_lens)

        x = enc_out
        out_lens = polated_lengths

        y = torch.tensor( [] )
        
        for i in range( x.size(0) ):
            x0 = torch.unsqueeze( x[i], dim = 0 )
            x0 = x0.permute( 0,2,1)
            x_out = torch.nn.functional.interpolate(x0, size = (out_lens[i]), mode='nearest-exact')
            z = torch.zeros( x_out.size(0), x_out.size(1), max_label_length )
            if z.size(2) > x_out.size(2):
                z[:,:,:x_out.size(2)] = x_out[:,:,:]
            else:
                z[:,:,:] = x_out[:,:,:z.size(2)]
            x_out = z.permute( 0, 2, 1 )
            y = torch.cat( [y, x_out], dim = 0)
            
        return y, outputs_lens

    ''' CaptioningTransformerの順伝播処理
    features: 画像特徴量 [バッチサイズ, 埋め込み次元]
    captions: 正解キャプション [バッチサイズ, 系列長]

    '''
    def forward(self, features: torch.Tensor):
        config = ConfigTrain()
        feature_lengths = torch.ones( (features.size(0) ) ) * features.size(1)

        # 単語埋め込み [バッチサイズ, 系列長]
        # -> [バッチサイズ, 系列長, 埋め込み次元]
        #embeddings = self.embed(captions)
        
        embeddings, outputs_lengths = self.downsample(features, feature_lengths)
        
        seq = embeddings.shape[1]
        
        # 位置エンコーディング
        embeddings = self.positional_encoding(embeddings.to(device=features.device))

        #features = features.unsqueeze(1)

        # 未来のキャプションを参照しないようにマスク行列を生成
        #tgt_mask = torch.tril(features.new_ones((seq, seq)))
        #tgt_mask = tgt_mask == 0

        # Transformerデコーダでキャプション生成
        # 画像の特徴も入力する
        for layer in self.decoder_layers:
            embeddings = layer(embeddings, features, tgt_mask = None)

        # [バッチサイズ, 系列長, 埋め込み次元]
        # -> [バッチサイズ, 系列長, 辞書の次元]
        preds = self.linear(embeddings)

        return preds, outputs_lengths


改修の第一は、順伝搬処理 foward を呼び出すときに、教師データをずらした captions_in がいらないことです。Transformer Decoder の入力は、encoder = CNNEncoder ( Transformer Encoder ) の出力と encoder の出力をダウンサンプリングした計算量なので、呼び出しには、encoder(imgs) の出力 features だけで良いです。あとは、features の time sequence の長さを算出して、self.ds_rate で指定した割合で features をダウサンプリングします。その値 embeddings と、embeddingsの time sequence の長さ outputs_lengths を得ます。 順伝播処理 forward は、embedings を Transformer Decoder Layers に入力して、その出力と、 time sequence の長さ outputs_lengths を返せば良いです。また、本に掲載されている自己回帰型の関数 sample はいりません。モデル出力は推論時も順伝搬処理 forward で良いです。その代わり、順伝搬処理の出力を CTC を前提にデコードしないければならないです。その関数を注意書きのあとに示します。

ここで注意書きです。downsample 関数に入力する特徴量 features の time sequence の長さ feature_lengths ですが、今は画像の特徴量なので、バッチにおけるすべての uterance につて同じ長さなので今回のような扱いができました。しかし、扱っている特徴量が言語などの場合、バッチの各 uterance について、特徴量の長さが異なる場合があります。その場合は、特徴量をファイルから読み込んだ時など、最初から特徴量の time seqence 長に気を使い、これを Transformer Decoder の順伝播処理 forward 関数に入力するべきです。downsample 関数の中の outputs_lengths の計算においては、バッチの各 uterance ごとの長さを計算するようになっています。time sequence の長さは、CTCLoss に入力する必要があるので、必ず計算しなければなりません。

def ctc_simple_decode(int_vector, token_list):
    ''' 以下の手順で,フレーム単位のCTC出力をトークン列に変換する
        1. 同じ文字が連続して出現する場合は削除
        2. blank を削除
    int_vector: フレーム単位のCTC出力(整数値列)
    token_list: トークンリスト
    output:     トークン列
    '''
    # 出力文字列
    output = []
    # 一つ前フレームの文字番号
    prev_token = -1
    # フレーム毎の出力文字系列を前から順番にチェックしていく
    for n in int_vector:
        n = n.item()
        if n != prev_token:
            # 1. 前フレームと同じトークンではない
            if n != 0:
                # 2. かつ,blank(番号=0)ではない
                # --> token_listから対応する文字を抽出し,
                #     出力文字列に加える
                output.append( token_list[n])
                if token_list[n] == '<end>':
                    break
            # 前フレームのトークンを更新
            prev_token = n

加えて、学習時には、CrossEntropyLossではなく、CTCLoss を計算します。

criterion = nn.CTCLoss(blank=0, reduction='mean',zero_infinity=False) 


outputs = F.log_softmax( outputs, dim=2 )
loss = criterion(outputs.transpose(0, 1),captions,outputs_lengths,caption_lengths)

CTCLoss を使うために、word_to_id と id_to_word の id を 1だけ後ろにずらしました。0 は、 <blank> に使うため、開けておかなければなりません。Pytorch の CTCLoss には、nn.CTCLoss と torch.nn.functional.ctc_loss があるようですが、今回は nn.CTCLoss を使っています。nn.CTCLoss は logits を入力することができず、log_softmax の出力を入力しています。

あと、自己回帰型の CaptioningTransformer( TransformerDeoder ) に必要な、先読みマスク( causal mask )は、必要ありません。

これらが、改修の要点です。

その他の改修点。

本のプログラムでは、encoder = CNNEncoder の出力は、batch_size × dim_embedding であしたが、これを、次のような CNNEncoder2 に改修しました。

import torch
from torch import nn
from torchvision import models

class CNNEncoder2(nn.Module):
    '''
    Show and tellのエンコーダ
    dim_embedding: 埋め込み次元
    '''
    def __init__(self, dim_embedding: int):
    #def __init__(self):
        super().__init__()

        # ImageNetで事前学習された
        # ResNet152モデルをバックボーンネットワークとする
        resnet = models.resnet152(weights="IMAGENET1K_V2")
        modules = list(resnet.children())[:-4]
        self.backbone = nn.Sequential(*modules)

        # デコーダへの出力
        #self.linear = nn.Linear(resnet.fc.in_features, dim_embedding)
        in_features = torch.tensor( ( 224 / 8 ) ** 2 ).to( torch.int16 )
        self.linear = nn.Linear( in_features, dim_embedding)

        
    '''
    エンコーダの順伝播
    imgs: 入力画像, [バッチサイズ, チャネル数, 高さ, 幅]
    '''
    def forward(self, imgs: torch.Tensor):
        # 特徴抽出 -> [バッチサイズ, 512, 28×28→dim_embbeding]
        # 今回はバックボーンネットワークは学習させない
        with torch.no_grad():
            features = self.backbone(imgs)
            features = features.flatten(2)
            
        # 全結合
        features = self.linear(features)

        return features

出力は、resnet152 の -4 の位置の出力 batch_size × 512 × 28 × 28(入力が batch_size × 3 × 224 × 224 の時) を batch_size × 512 × dim_embedding に整形しました。3階のテンソルなので、 Transformer で扱いやすくなりました。ちなみに、time sequence の長さが 512 で、教師データの time sequence の長さが 18 のようなので、self.ds_rate = 0.1 としました。

学習曲線と学習結果

学習曲線

損失のグラフを train と validation について掲載させていただきます。

fig1.png

fig2.png

学習結果

まだ、開発途上なので精度はあまりよくありませんが、推論した結果を掲載させていただきます。

fig3.png
fig4.png
fig5.png
fig6.png
fig7.png
fig8.png
fig9.png
fig10.png
fig11.png
fig12.png
fig13.png
fig14.png
fig15.png
fig16.png
fig17.png
fig18.png
fig19.png
fig20.png
fig21.png

参考のため、使ったプログラムを Github にアップしておきます。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?