0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

SwinTransformerV2 に padding mask を入力できるようにする試み。

Last updated at Posted at 2025-07-01

改修を試みた動機

今までに、SwinTransformer を使うとき、画像の Transforms.Resize は、正方形の 224, 224 の形に Resize していた。しかし、ImageNet の元の画像などは、正方形とは限らない。縦横比が 1 ではない画像が普通だと思います。一方、SwinTransformer の tiny は 224,224 の正方形画像を扱うようにできているという印象を持っています。縦横比が 1 でない画像を、縦横比を変えずに Resize して、どうやって SwinTransformer で扱ったら良いか考えました。

考え付いたのは、例えば、448×384 の画像は、224 × 192 の画像に縮小します。しかし、image size は 224 × 224 にしておいて、Width の方の 224 - 192 の部分は、0 のピクセルとして、padding mask を True にするという案です。

これにより、縦横比を変えずに、SwinTransformer で、Image Classification ができるのではないかと考えました。そのためには、SwinTransformer で、padding mask を扱えるように改修しなければなりませんでした。

Transforms.Resize と collate function の準備。

準備した Resize クラス

class Resize2:
    '''
    画像をアスペクト比を保持してリサイズするクラス
    Resize した結果 image の長辺の長さ
    '''
    def __init__(self, result_size: int):
        self.result_size = result_size

    '''
    元の画像を長辺が self.result_size になるように画像をリサイズする関数
    img   : リサイズする画像
    '''
    def __call__(self, img: Image):
        # width が長辺か、height が長辺か調べる。
        width, height = img.size
        max_size = max( width, height )
        w_flag = False
        h_flag = False
        if max_size == width:
            w_flag = True
        else:
            h_flag = True
        
        # resize する weidth と height を決定する
        if w_flag == True:
           r_ratio = self.result_size / width
           resized_width = self.result_size
           resized_height = int( height * r_ratio )  
        else:
           r_ratio = self.result_size / height
           resized_height = self.result_size
           resized_width = int( width * r_ratio )


        # 指定した大きさに画像をリサイズ
        img = F.resize(img, (resized_height, resized_width))

        return img

Resize2 クラスのコンストラクト時には、出力する画像の長編のピクセル数を指定します。今の場合 224 です。縦横比が 1 でない画像については、出力画像は 224 × 224 とはなりません。

224 × 224 の画像と padding mask を出力する collate_func5 関数

def collate_func5(batch: Sequence[Tuple[Union[torch.Tensor, str]]]):
    imgs, targets = zip(*batch)

    # 与えられた image の長辺の長さに画像サイズ max_height, max_width を合わせる
    result_size = max( imgs[0].shape[1], imgs[0].shape[2] )
    max_height = result_size
    max_width = result_size
    for img in imgs:
        height, width = img.shape[1:]
        max_height = max(max_height, height)
        max_width = max(max_width, width)

    # (batch数、channel=3, max_height, max_width) で全ピクセル 0 の画像を確保。
    imgs = batch[0][0].new_zeros(
        (len(batch), 3, max_height, max_width))
    # (batch数、max_height, max_width) で全値 True のブールでマスクを確保
    masks = batch[0][0].new_ones(
        (len(batch), max_height, max_width), dtype=torch.bool)
    targets = []
    for i, (img, target) in enumerate(batch):
        #実際のimage を画像リストに代入
        height, width = img.shape[1:]
        imgs[i, :, :height, :width] = img
        # マスクの画像領域には偽の値を設定
        masks[i, :height, :width] = False
        # target ラベルを追加
        targets.append(target)
    
    # target ラベルを torch.tensor 化。
    targets = torch.tensor( targets )

    return imgs, masks, targets

B × 3 × 224 × 224 の画像と、B × 224 × 224 のマスクと、target ラベルを出力します。

SwinTransformerV2 の改修箇所。

〇入力を imgs と padding_mask にして、 PatchEmbed と BasicLayer を padding_mask 対応とした。

〇class PatchEmbed
 x の次元数の変形に伴い、padding_mask も F.interpolat 関数で変形した。

〇class BasicLayer
 blk 関数と downsample(PatchMerging) を padding_mask 対応とした。

〇blk 関数対応
 blk 関数の中身である SwinTransformerBlock を padding_mask 対応とした。

〇 downsample 対応
 x が downsample で変形されているので、x のシーケンス長に合わせて、padding_mask を F.interpolate した。

〇SwinTransformerBloc 対応
 x だけでなく、padding_mask にも shift と window_partition を行い、self.attn 関数に入力した。

〇self.attn 関数(window_attention) 対応
 padding mask に shift と window_partition を行ったものを view 関数で attn の変形に合わせて(attn_mask に合わせて)、ブール値をfloat に変換して -1e9 を乗算し、attn に加えて、softmax を計算した。

精度の測定について。

データは、ImageNet の train データと validation データの 1/100 を使いました。batch_size は 8。 epoch は 25 です。len( train_loader )= 1602、len( val_loader ) = 63 です。

測定結果

loss と accuracy について、Trainデータと Validation データのグラフを掲載する。for comparing のグラフは、padding_mask なしのデータです。for comparing のないグラフは padding_mask を考慮した場合のデータです。

loss.png

acc.png

Validation accuracy の最大値は、padding_mask ありが 0.685 で、padding_mask なしが 0.671 でした。

測定に使用したプログラムを github に置いておきます。

よろしくお願いいたします。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?