前提
概要
FixMatchは教師のあるサンプルと教師のないサンプルを用いて学習を行う「半教師あり学習」を行う手法の1つです。FixMatchの公式実装はTensorFlow
で構築されていますが、TensorFlow
はバージョンの再現が大変なので、気軽に動かすにはPyTorch
実装の方が使いやすいです。
当記事では下記のFixMatchのPyTorch
実装について確認し、FixMatchを理解しながらPyTorch
の機能について合わせて抑えます。
FixMatchのパフォーマンス
FixMatchのパフォーマンスについては主にCIFAR10やCIFAR100によって測定されます。
・CIFAR10
#Labels | 40 | 250 | 4000 |
---|---|---|---|
Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 |
This code | 93.60 | 95.31 | 95.77 |
Acc. curve | link | link | link |
FixMatchのCIFAR10でのパフォーマンス(Top-1 acc)は上記より確認できます。CIFAR10のタスクについては下記のような画像を分類させると理解しておけば良いです。
・CIFAR100
#Labels | 400 | 2500 | 10000 |
---|---|---|---|
Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 |
This code | 57.50 | 72.93 | 78.12 |
Acc. curve | link | link | link |
FixMatchのCIFAR10でのパフォーマンス(Top-1 acc)は上記より確認できます。CIFAR100の分類カテゴリは下記より確認できます。
詳細
train.py
前節で確認したリポジトリでは学習の実行にあたって下記のようなtrain.py
を動かします。
$ python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/cifar10@4000.5
実行結果のmodelは--out
で指定したファイルに保存されます。その他の実行時に与える引数は--dataset
ではcifar10
やcifar100
、--num-labeled
では前節の表に記載の数字(CIFAR10の場合は40・250・4000のどれか)、--batch-size
ではバッチに含まれるサンプルの数をそれぞれ指定します。
DataLoader周りの実装
FixMatchの実装を行う上で特徴的なのがDataLoader周りの実装です。
FixMatchのDiagram(FixMatch論文 Fig.1)
FixMatchでは上図で表されるようにweak augmentationとstrong augmentationの2種類のData Augmentation(データ拡張)が用いられます。weak augmentationにはflip(反転)やshift(平行移動)が用いられる一方で、strong augmentationにはReMixMatch論文で提案されたCTAugment(Control Theory Augment)やRandAugmentが用いられます。
以下、strong augmentationの実装について確認します。
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from dataset.cifar import DATASET_GETTERS
labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset](
args, './data')
train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
labeled_trainloader = DataLoader(
labeled_dataset,
sampler=train_sampler(labeled_dataset),
batch_size=args.batch_size,
num_workers=args.num_workers,
drop_last=True)
unlabeled_trainloader = DataLoader(
unlabeled_dataset,
sampler=train_sampler(unlabeled_dataset),
batch_size=args.batch_size*args.mu,
num_workers=args.num_workers,
drop_last=True)
test_loader = DataLoader(
test_dataset,
sampler=SequentialSampler(test_dataset),
batch_size=args.batch_size,
num_workers=args.num_workers)
上記にtrain.py
におけるDataLoader
周りの実装を抽出しました。labeled_loader
がラベル付きのデータセット、unlabeled_loader
がラベルなしのデータセット、test_loader
がテストデータのデータセットにそれぞれ対応するDataLoader
です。まずDataLoader
の引数については、labeled_loader
とunlabeled_loader
ではsampler
にRandomSampler
が基本的に用いられるかつdrop_last=True
が指定される一方で、test_loader
ではsampler
にSequentialSampler
が用いられるかつdrop_last
が指定されない(デフォルトはFalse
)ことが確認できます。
また、それぞれのDataLoader
の第一引数に与えるデータセットはdataset/cifar.py
から読み込んだDATASET_GETTERS
を用います。dataset/cifar.py
では下記のように実装されています。
from PIL import Image
from torchvision import datasets
def get_cifar10(args, root):
transform_labeled = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(size=32,
padding=int(32*0.125),
padding_mode='reflect'),
transforms.ToTensor(),
transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
])
transform_val = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
])
base_dataset = datasets.CIFAR10(root, train=True, download=True)
train_labeled_idxs, train_unlabeled_idxs = x_u_split(
args, base_dataset.targets)
train_labeled_dataset = CIFAR10SSL(
root, train_labeled_idxs, train=True,
transform=transform_labeled)
train_unlabeled_dataset = CIFAR10SSL(
root, train_unlabeled_idxs, train=True,
transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
test_dataset = datasets.CIFAR10(
root, train=False, transform=transform_val, download=False)
return train_labeled_dataset, train_unlabeled_dataset, test_dataset
class CIFAR10SSL(datasets.CIFAR10):
def __init__(self, root, indexs, train=True,
transform=None, target_transform=None,
download=False):
super().__init__(root, train=train,
transform=transform,
target_transform=target_transform,
download=download)
if indexs is not None:
self.data = self.data[indexs]
self.targets = np.array(self.targets)[indexs]
def __getitem__(self, index):
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
DATASET_GETTERS = {'cifar10': get_cifar10,
'cifar100': get_cifar100}
上記のget_cifar10
関数ではtrain_labeled_dataset
、train_unlabeled_dataset
、test_dataset
を返すことが確認できます。train_labeled_dataset
の取得にあたってはCIFAR10SSL
クラスが用いられ、引数にtransform=transform_labeled
などが与えられます。train_unlabeled_dataset
では同様にCIFAR10SSL
クラスが用いられる一方で、引数にtransform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std)
が与えられていることに着目しておくと良いです。また、test_dataset
の取得にあたってはtorchvision.datasets.CIFAR10
が用いられます。
CIFAR10SSL
の実装にあたっては、torchvision.datasets.CIFAR10
をベースに__getitem__
でバッチの取得について記載されています。__getitem__
内部ではimg = Image.fromarray(img)
のようにPillow(PIL)の形式に変換されることも確認できます。
また、train_unlabeled_dataset
の構築の際に引数に与えられるTransformFixMatch
については下記から確認できます。
from torchvision import transforms
from .randaugment import RandAugmentMC
class TransformFixMatch(object):
def __init__(self, mean, std):
self.weak = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(size=32,
padding=int(32*0.125),
padding_mode='reflect')])
self.strong = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(size=32,
padding=int(32*0.125),
padding_mode='reflect'),
RandAugmentMC(n=2, m=10)])
self.normalize = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)])
def __call__(self, x):
weak = self.weak(x)
strong = self.strong(x)
return self.normalize(weak), self.normalize(strong)
上記の確認にあたっては__init__
関数におけるself.strong
への代入にあたってRandAugmentMC(n=2, m=10)
が用いられていることが確認できます。
RandAugmentの実装
dataset/randaugment.py
では下記のようにRandAugmentMC
クラスが実装されています。
def fixmatch_augment_pool():
# FixMatch paper
augs = [(AutoContrast, None, None),
(Brightness, 0.9, 0.05),
(Color, 0.9, 0.05),
(Contrast, 0.9, 0.05),
(Equalize, None, None),
(Identity, None, None),
(Posterize, 4, 4),
(Rotate, 30, 0),
(Sharpness, 0.9, 0.05),
(ShearX, 0.3, 0),
(ShearY, 0.3, 0),
(Solarize, 256, 0),
(TranslateX, 0.3, 0),
(TranslateY, 0.3, 0)]
return augs
class RandAugmentMC(object):
def __init__(self, n, m):
assert n >= 1
assert 1 <= m <= 10
self.n = n
self.m = m
self.augment_pool = fixmatch_augment_pool()
def __call__(self, img):
ops = random.choices(self.augment_pool, k=self.n)
for op, max_v, bias in ops:
v = np.random.randint(1, self.m)
if random.random() < 0.5:
img = op(img, v=v, max_v=max_v, bias=bias)
img = CutoutAbs(img, int(32*0.5))
return img
ops = random.choices(self.augment_pool, k=self.n)
によってself.augment_pool
で定義された14のAugmentationの手法から処理をサンプリングし、Augmentation処理が行われます。14種類の各処理は下記のように実装されています。
import PIL
def AutoContrast(img, **kwarg):
return PIL.ImageOps.autocontrast(img)
def Brightness(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
return PIL.ImageEnhance.Brightness(img).enhance(v)
def Color(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
return PIL.ImageEnhance.Color(img).enhance(v)
def Contrast(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
return PIL.ImageEnhance.Contrast(img).enhance(v)
def Equalize(img, **kwarg):
return PIL.ImageOps.equalize(img)
def Identity(img, **kwarg):
return img
def Posterize(img, v, max_v, bias=0):
v = _int_parameter(v, max_v) + bias
return PIL.ImageOps.posterize(img, v)
def Rotate(img, v, max_v, bias=0):
v = _int_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
return img.rotate(v)
def Sharpness(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
return PIL.ImageEnhance.Sharpness(img).enhance(v)
def ShearX(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
def ShearY(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
def Solarize(img, v, max_v, bias=0):
v = _int_parameter(v, max_v) + bias
return PIL.ImageOps.solarize(img, 256 - v)
def TranslateX(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
v = int(v * img.size[0])
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateY(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
v = int(v * img.size[1])
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
当記事で参照した実装ではPillow(PIL)がAugment処理に用いられていることは着目しておくと良いと思います。Pillowについては下記で詳しく取り扱いましたので合わせてご確認ください。