0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

FixMatchのPyTorch実装の確認

Last updated at Posted at 2025-01-24

前提

概要

FixMatchは教師のあるサンプルと教師のないサンプルを用いて学習を行う「半教師あり学習」を行う手法の1つです。FixMatchの公式実装はTensorFlowで構築されていますが、TensorFlowはバージョンの再現が大変なので、気軽に動かすにはPyTorch実装の方が使いやすいです。

当記事では下記のFixMatchのPyTorch実装について確認し、FixMatchを理解しながらPyTorchの機能について合わせて抑えます。

FixMatchのパフォーマンス

FixMatchのパフォーマンスについては主にCIFAR10やCIFAR100によって測定されます。

・CIFAR10

#Labels 40 250 4000
Paper (RA) 86.19 ± 3.37 94.93 ± 0.65 95.74 ± 0.05
This code 93.60 95.31 95.77
Acc. curve link link link

FixMatchのCIFAR10でのパフォーマンス(Top-1 acc)は上記より確認できます。CIFAR10のタスクについては下記のような画像を分類させると理解しておけば良いです。

FIxMatch1.png

・CIFAR100

#Labels 400 2500 10000
Paper (RA) 51.15 ± 1.75 71.71 ± 0.11 77.40 ± 0.12
This code 57.50 72.93 78.12
Acc. curve link link link

FixMatchのCIFAR10でのパフォーマンス(Top-1 acc)は上記より確認できます。CIFAR100の分類カテゴリは下記より確認できます。

FIxMatch2.png

詳細

train.py

前節で確認したリポジトリでは学習の実行にあたって下記のようなtrain.pyを動かします。

$ python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/cifar10@4000.5

実行結果のmodelは--outで指定したファイルに保存されます。その他の実行時に与える引数は--datasetではcifar10cifar100--num-labeledでは前節の表に記載の数字(CIFAR10の場合は40・250・4000のどれか)、--batch-sizeではバッチに含まれるサンプルの数をそれぞれ指定します。

DataLoader周りの実装

FixMatchの実装を行う上で特徴的なのがDataLoader周りの実装です。

FixMatch3.png
FixMatchのDiagram(FixMatch論文 Fig.1)

FixMatchでは上図で表されるようにweak augmentationとstrong augmentationの2種類のData Augmentation(データ拡張)が用いられます。weak augmentationにはflip(反転)やshift(平行移動)が用いられる一方で、strong augmentationにはReMixMatch論文で提案されたCTAugment(Control Theory Augment)やRandAugmentが用いられます。

以下、strong augmentationの実装について確認します。

train.py
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from dataset.cifar import DATASET_GETTERS

labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset](
        args, './data')

train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler

    labeled_trainloader = DataLoader(
        labeled_dataset,
        sampler=train_sampler(labeled_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True)

    unlabeled_trainloader = DataLoader(
        unlabeled_dataset,
        sampler=train_sampler(unlabeled_dataset),
        batch_size=args.batch_size*args.mu,
        num_workers=args.num_workers,
        drop_last=True)

    test_loader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers)

上記にtrain.pyにおけるDataLoader周りの実装を抽出しました。labeled_loaderがラベル付きのデータセット、unlabeled_loaderがラベルなしのデータセット、test_loaderがテストデータのデータセットにそれぞれ対応するDataLoaderです。まずDataLoaderの引数については、labeled_loaderunlabeled_loaderではsamplerRandomSamplerが基本的に用いられるかつdrop_last=Trueが指定される一方で、test_loaderではsamplerSequentialSamplerが用いられるかつdrop_lastが指定されない(デフォルトはFalse)ことが確認できます。

また、それぞれのDataLoaderの第一引数に与えるデータセットはdataset/cifar.pyから読み込んだDATASET_GETTERSを用います。dataset/cifar.pyでは下記のように実装されています。

dataset/cifar.py
from PIL import Image
from torchvision import datasets

def get_cifar10(args, root):
    transform_labeled = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32,
                              padding=int(32*0.125),
                              padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
    ])
    base_dataset = datasets.CIFAR10(root, train=True, download=True)

    train_labeled_idxs, train_unlabeled_idxs = x_u_split(
        args, base_dataset.targets)

    train_labeled_dataset = CIFAR10SSL(
        root, train_labeled_idxs, train=True,
        transform=transform_labeled)

    train_unlabeled_dataset = CIFAR10SSL(
        root, train_unlabeled_idxs, train=True,
        transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))

    test_dataset = datasets.CIFAR10(
        root, train=False, transform=transform_val, download=False)

    return train_labeled_dataset, train_unlabeled_dataset, test_dataset
    
class CIFAR10SSL(datasets.CIFAR10):
    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

DATASET_GETTERS = {'cifar10': get_cifar10,
                   'cifar100': get_cifar100}

上記のget_cifar10関数ではtrain_labeled_datasettrain_unlabeled_datasettest_datasetを返すことが確認できます。train_labeled_datasetの取得にあたってはCIFAR10SSLクラスが用いられ、引数にtransform=transform_labeledなどが与えられます。train_unlabeled_datasetでは同様にCIFAR10SSLクラスが用いられる一方で、引数にtransform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std)が与えられていることに着目しておくと良いです。また、test_datasetの取得にあたってはtorchvision.datasets.CIFAR10が用いられます。

CIFAR10SSLの実装にあたっては、torchvision.datasets.CIFAR10をベースに__getitem__でバッチの取得について記載されています。__getitem__内部ではimg = Image.fromarray(img)のようにPillow(PIL)の形式に変換されることも確認できます。

また、train_unlabeled_datasetの構築の際に引数に与えられるTransformFixMatchについては下記から確認できます。

dataset/cifar.py
from torchvision import transforms
from .randaugment import RandAugmentMC

class TransformFixMatch(object):
    def __init__(self, mean, std):
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect')])
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    def __call__(self, x):
        weak = self.weak(x)
        strong = self.strong(x)
        return self.normalize(weak), self.normalize(strong)

上記の確認にあたっては__init__関数におけるself.strongへの代入にあたってRandAugmentMC(n=2, m=10)が用いられていることが確認できます。

RandAugmentの実装

dataset/randaugment.pyでは下記のようにRandAugmentMCクラスが実装されています。

dataset/randaugment.py
def fixmatch_augment_pool():
    # FixMatch paper
    augs = [(AutoContrast, None, None),
            (Brightness, 0.9, 0.05),
            (Color, 0.9, 0.05),
            (Contrast, 0.9, 0.05),
            (Equalize, None, None),
            (Identity, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 0.9, 0.05),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (TranslateX, 0.3, 0),
            (TranslateY, 0.3, 0)]
    return augs
    
class RandAugmentMC(object):
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            v = np.random.randint(1, self.m)
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        img = CutoutAbs(img, int(32*0.5))
        return img

ops = random.choices(self.augment_pool, k=self.n)によってself.augment_poolで定義された14のAugmentationの手法から処理をサンプリングし、Augmentation処理が行われます。14種類の各処理は下記のように実装されています。

dataset/randaugment.py
import PIL

def AutoContrast(img, **kwarg):
    return PIL.ImageOps.autocontrast(img)


def Brightness(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Brightness(img).enhance(v)


def Color(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Color(img).enhance(v)

def Contrast(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Contrast(img).enhance(v)

def Equalize(img, **kwarg):
    return PIL.ImageOps.equalize(img)

def Identity(img, **kwarg):
    return img

def Posterize(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.posterize(img, v)

def Rotate(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.rotate(v)

def Sharpness(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Sharpness(img).enhance(v)

def ShearX(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))

def ShearY(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))

def Solarize(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.solarize(img, 256 - v)

def TranslateX(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[0])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))

def TranslateY(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[1])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))

当記事で参照した実装ではPillow(PIL)がAugment処理に用いられていることは着目しておくと良いと思います。Pillowについては下記で詳しく取り扱いましたので合わせてご確認ください。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?