11
7

More than 5 years have passed since last update.

pytorchで初めてゼロから書くSOTA画像分類器(下)(Shake-Shake Regularization実装)

Last updated at Posted at 2019-03-13

はじめに

前編はこちら:
pytorchで初めてゼロから書くSOTA画像分類器(上)

前回はポンコツモデルにBatchNormやResNet、いろいろ導入してCIFAR-10に対する正解率を73%まで上げました。
今回は躍進して、2017年SOTAだったShake-Shake Regularizationの実装に挑戦します。

この記事はDeep Learningの理論を知っているが、実装をあまりしてこなかった人が対象です。

Shake-Shake Regularization

2017年にGastaldi氏により発表されました。
肝となる技術は一つだけで、且つとても簡単なものでした。
要は、画像分類をするときに、普通はdata augmentationをするでしょう?
で、それが画像のスペースでだけ行われるものであると。
コンピューターにとっては画像スペースとネットワークの中の表現スペース(internal representation)と違いはありません。
そこで、ResNetのbranchごとの重みをいじれば、internal representation spaceででも、data augmentation(みたいなこと)ができるのではないか、という話です。
image.png

2 branches + shortcutのResNetを想定して、forwardの時にbranchごとに$\alpha$と$1-\alpha$の重みを与え、backwardの時に$\beta$と$1-\beta$の重みを与え、テストの時に重みを0.5/0.5にします。これを従来のResNet blockにすり替われば、それでもう完成します。

実装

前編のポンコツ物と比べていろいろ改修していますので、逐一触れていきます。

データセットの準備

if transform_mean.any() == None:
    transform_mean = np.array([0.4914, 0.4822, 0.4465])
if transform_std.any() == None:
    transform_std = np.array([0.2470, 0.2435, 0.2616])

これは予め算出したデータのチャネルごとの平均と分散の値です。データセットを正規化するのに使います。画像データだけでなく、多くの場合はそれをしないと学習がうまく行きません。


train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std)
])

train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), train_transform)
test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), test_transform)

ToTensor()以外なんか増えてきます。
Normalizeは上記の正規化をする関数。
RandomCropとRandomHorizontalFlipはdata augmentationをしています。データは常に有限であり、真の分布を表すにはデータ数が足りないことが多い(スパースである)ため、ネットワークが真の分布をうまく掴めないことが多いです。Data augmentation(増幅)とはデータに対し、画像を反転するなり、その中の一部だけと切り取るなり、データセットの変化を多様にする作業です。

data loaderの作成


train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=128, 
                                           shuffle=True, 
                                           num_workers=8,
                                           pin_memory=True,
                                          drop_last=True
                                          )

test_loader = torch.utils.data.DataLoader(test_dataset, 
                                           batch_size=128, 
                                           shuffle=False, 
                                           num_workers=8,
                                           pin_memory=True,
                                           drop_last=False
                                         )

train_dataset_size = len(train_dataset)
test_dataset_size = len(test_dataset)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

最後にはbatch sizeやほかのパラメータを決めていきます。
shuffleとはdataloaderを作成する際にshuffleするかを決めます。訓練のときはシャッフルするようにします。
num_workerとpin_memoryは読み込み速度に関連しており、結果に影響はないはずです。
drop_lastはデータ数をbatch sizeで割り算して、もし整数で割り切れない場合は、最後のバッチを捨てることです。

本体

ここが難関です。初めての人がここでどうやって始めるものかと思ってしまうかもしれませんね。
前編で解説した通り、まずは中の小さい部品を予め定義して、後で流用すればいいと思います。
ただ、構造が複雑になっている分、同じ要領で更なる分解することが必要となります。
経験者には必要のないことですが、全体の構図を把握するために、私のしたことは紙に全部書くことです。
こんな感じに:
image.png
層ごとに、入力/出力チャネル数、カーネルサイズ、ストライド、パッディング、feature mapのサイズ、
それらを漏れなく書きます。
これを完成すれば、モデルをどのようにバラすかがはっきり分かってきます。
ここだとModel>Stages>Blocks>Pathsという具合にバラせばいい、というのが見えてきます。


class ResPath(nn.Module):
    def __init__(self, in_chan, out_chan, stride):
        super(ResPath, self).__init__()
        self.conv1 = nn.Conv2d(in_chan, out_chan, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = nn.Conv2d(out_chan, out_chan, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        x = self.bn1(self.conv1(F.relu(x, inplace=False)))
        x = self.bn2(self.conv2(F.relu(x, inplace=False)))
        return x

class DownSamplePath(nn.Module):
    def __init__(self, in_chan, out_chan):
        super(DownSamplePath, self).__init__()
        self.conv1 = nn.Conv2d(in_chan, in_chan, 1, 1, 0, bias=False)
        self.conv2 = nn.Conv2d(in_chan, in_chan, 1, 1, 0, bias=False)  
        self.bn1 = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        x = F.relu(x, inplace=False)

        x1 = F.avg_pool2d(x, 1, stride=2, padding=0)
        x1 = self.conv1(x1)

        x2 = F.pad(x[:, :, 1:, 1:], (0, 1, 0, 1)) 
        x2 = F.avg_pool2d(x2, 1, stride=2, padding=0)
        x2 = self.conv2(x2)

        y = torch.cat([x1, x2], dim=1)
        y = self.bn1(y)
        return y

まずは一番細かい、Pathsレベルの構造を定義します。
2種類あって、まずはResPathという、普通のResNetの中のやつですね。
Relu>Conv>BNという順で、変哲のないものです。
F.reluのinplaceは計算のとき元の値のコピーを残すかの設定です。Trueにしてエラーが出てこないなら、スピードアップができます。

DownSamplePathは二手に分かれて、stride 2のaverage poolingによるダウンサンプリングを行います。
一つ目はインプットをそのままで、二つ目はインプットを右と下に1pixelだけずらしてからダウンサンプリングをします。
これにより同じ画像の違う特徴が抽出できるはずです。
抽出したらconcatをして、ここは完了します。


class ShakeFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x1, x2, alpha, beta):
        ctx.save_for_backward(x1, x2, alpha, beta)
        y = x1 * alpha + x2 * (1 - alpha)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x1, x2, alpha, beta = ctx.saved_tensors
        grad_x1 = grad_x2 = grad_alpha = grad_beta = None
        grad_x1 = grad_output * beta
        grad_x2 = grad_output * (1 - beta)

        return grad_x1, grad_x2, grad_alpha, grad_beta
ShakeFunc = ShakeFunc.apply

class ResBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride):
        super(ResBlock, self).__init__()
        self.path1 = ResPath(in_chan, out_chan, stride)
        self.path2 = ResPath(in_chan, out_chan, stride)

        self.down_sample = nn.Sequential()
        if in_chan != out_chan:
            self.down_sample.add_module("DownSamplePath", DownSamplePath(in_chan, out_chan))

    def get_alpha_beta(self, batch_size, is_training, device):
        """only Shake-Shake-Image is implemented here"""
        if is_training:
            alpha = torch.rand((batch_size, 1, 1, 1))
            beta = torch.rand((batch_size, 1, 1, 1))
        else:
            alpha = torch.ones((batch_size, 1, 1, 1)) * 0.5
            beta = torch.ones((batch_size, 1, 1, 1)) * 0.5 
        alpha = alpha.to(device)
        beta = beta.to(device)
        return alpha, beta

    def forward(self, x):
        x1 = self.path1(x)
        x2 = self.path2(x)
        alpha, beta = self.get_alpha_beta(x.size(0), self.training, x.device) #self.trainingは現在訓練モードであるかどうかを示すbooleanです。
        y = ShakeFunc(x1, x2, alpha, beta)
        return self.down_sample(x) + y    

次はBlock構造です。
まずはResBlock自体について見てみましょう。
BlockごとにResPathが二つ、DownSamplePathが一つで、計三つのPathがあります。
forwardを見てみるとその通りに書いてあります。
ただ、ResPathの両PathがDownSamplePathと合算する前に、ある作業をしなければなりません。
それがShake-Shakeの肝となるBranchごとの重みを与えることです。
それを行うのはセッション冒頭にあるShakeFuncです。

このShakeFuncに対しては正直、どうしたものかと手こずりました。
図を見たら分かると思いますが、forwardの時がalphaと、逆伝播の時はbetaと、異なる定数を重みにしてしまいます。
forwardは簡単だったが、backwardの書き方が分かりませんでした。ここはhysts氏の実装を参考にしました。
どうやらtorch.autograd.Functionごと継承して、F.reluやF.tanhみたいなfunctionを自分で作れるようです。
forward/backwardをそれぞれ定義したら、順/逆伝播も対応してくれます。
ctxという一時的なメモリーみたいな変数も作ることになります。

コードにあるStatic methodとtorch.autograd.Functionに関しては、
この記事Pytorch Documentation 1Pytorch Documentation 2にも参考しました。
正直ここは今でも完全に把握しているわけでもないんです。

ちなみに、上にあるget_alpha_betaの引数のうち、今が訓練しているかどうかを示すものを求めているわけですが、
自分は最初はmodel.train()を与えました。これのせいでバグって丸一日費やしてやっと解消しました。
正しくはself.trainingです。
先入観を持たず、少しでも不確かなところがあれば、必ずdocumentationを見ましょう。


class ResStage(nn.Module):
    def __init__(self, in_chan, out_chan, stride, n_blocks=4):
        super(ResStage, self).__init__()
        self.stage = nn.Sequential()
        for idx in range(n_blocks):
            if idx == 0:
                self.stage.add_module("block{}".format(idx+1), ResBlock(in_chan, out_chan, stride=stride))
            else:
                self.stage.add_module("block{}".format(idx+1), ResBlock(out_chan, out_chan, stride=1))      

    def forward(self, x):
        x = self.stage(x)
        return x

ここは複数のBlockを連結してStageにします。
わざわざStageのクラスを定義する必要はないかもしれませんが、自分はこの方がやりやすいです。

def initialize_weights(module):
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight.data, mode='fan_out')
    elif isinstance(module, nn.BatchNorm2d):
        module.weight.data.fill_(1)
        module.bias.data.zero_()
    elif isinstance(module, nn.Linear):
        module.bias.data.zero_()

ラストに入る前に、重みの初期化も定義しておきます。
ちなみに、BatchNormをやっているからか、この部分がなくても結果が全く変わらなかったです。


class Shakeshake(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, n_channels[0], 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(n_channels[0])
        self.stage1 = ResStage(n_channels[0], n_channels[0], stride=1)
        self.stage2 = ResStage(n_channels[0], n_channels[1], stride=2)
        self.stage3 = ResStage(n_channels[1], n_channels[2], stride=2)
        self.fc1 = nn.Linear(n_channels[2], n_classes)

        # initialize weights
        self.apply(initialize_weights)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = F.adaptive_max_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

各部品を定義しておいたら、本体はいたってシンプルです。

Optimizerとその他諸々


nepoch = 1800

base_chan = 32
n_channels = [base_chan, base_chan * 2, base_chan * 4]
n_classes = 10

model = Shakeshake(n_channels, n_classes)
model = model.to(device)

loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.009, momentum=0.9, weight_decay=1e-4, nesterov=True)

T_max = nepoch * train_dataset_size
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)

global_step = 0

論文に従ってここはSGDを使います。学習率については論文では0.2でした。疑いつつ試したが、やはりうまく行かなかったので、小さくしました。ちなみに、0.005-0.01の範囲では最適なようです(少なくともここでは)。momentumやweight_decayについてはげんきゅうされなかったので、穏便に付きました。
schedulerはSGDのような自動的に調節しないoptimizerに対して調整を行う機関です。論文に従ってCosine Annealingを使います。Mini batchごとに調整してくれるようにできています。

## auxilliary functions
def get_config():
    config = OrderedDict({'name':time_run,
                          'nepoch':nepoch,
                          'base_chan':base_chan,
                          'loss_func':str(loss_func),
                          'optimizer':str(optimizer),
                          'scheduler':str(scheduler.__dict__),
                          'model':str(model),
                          'train_loader':str(train_loader.__dict__), 
                          'test_loader':str(test_loader.__dict__)})
    return config

最後にログを残すためのものも定義します。これだけ残せば何を実行したのかがはっきりします。

Training Loop


def train(epoch, model, loss_func, train_loader, optimizer, scheduler, writer):
    global global_step

    since = time.time() #所要時間も記録します。

    model.train()

    running_loss = 0.0            
    running_correct = 0

    for i, data in enumerate(train_loader):
        global_step +=1
        inputs, labels = data

        if i == 0:
            image = torchvision.utils.make_grid(
                inputs, normalize=True, scale_each=True)
            writer.add_image('Train/Image', image, epoch)

        scheduler.step()        

        writer.add_scalar('Train/LearningRate',
                            scheduler.get_lr()[0], global_step)

        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        _, predicted = torch.max(outputs.data, 1)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_correct += np.double(torch.sum(predicted == labels.data))

        writer.add_scalar('Train/RunningLoss', running_loss , global_step)
        writer.add_scalar('Train/RunningAccuracy', running_correct , global_step)

    train_loss = running_loss / train_dataset_size
    train_acc = running_correct / train_dataset_size

    elapsed = time.time() - since

    writer.add_scalar('Train/Loss', train_loss, epoch)
    writer.add_scalar('Train/Accuracy', train_acc, epoch)
    writer.add_scalar('Train/Time', elapsed, epoch)

    train_log = OrderedDict({
    'epoch':
    epoch,
    'train':
    OrderedDict({
        'loss': train_loss,
        'accuracy': train_acc,
        'time': elapsed,
        }),
    })

    print('Train Loss:{:.6f} Accuracy:{:.4f} \nTime:{:.1f}s'.format(train_loss, train_acc, elapsed))        

    return train_log, train_acc

前編のポンコツと比べてはかなり増えてきましたが、大半がログを残すためのものです。
writerはtensorboardXの記録係みたいなもので、add_scalar(名前, 記録する変数, タイム)で記録します。
過剰かもしれませんが、万が一に備え、ロス値自体も文字としてOrderedDictでアウトプットしています。

def test(epoch, model, loss_func, test_loader, writer):
    global global_step
    since = time.time()

    model.eval()

    running_loss = 0.0            
    running_correct = 0

    with torch.no_grad():
        for i, data in enumerate(test_loader):
            global_step +=1
            inputs, labels = data

            if i == 0:
                image = torchvision.utils.make_grid(
                    inputs, normalize=True, scale_each=True)
                writer.add_image('Test/Image', image, epoch)
                #tensorboardXのwriterはscalarだけでなく、画像やhistogramも記録できます。

            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            _, predicted = torch.max(outputs.data, 1)

            running_loss += loss.item()
            running_correct += np.double(torch.sum(predicted == labels.data))
    test_loss = running_loss / test_dataset_size
    test_acc = running_correct / test_dataset_size

    elapsed = time.time() - since

    if epoch > 0:
        writer.add_scalar('Test/Loss', test_loss, epoch)
    writer.add_scalar('Test/Accuracy', test_acc, epoch)
    writer.add_scalar('Test/Time', elapsed, epoch)    

    test_log = OrderedDict({
    'epoch':
    epoch,
    'test':
    OrderedDict({
        'loss': test_loss,
        'accuracy': test_acc,
        'time': elapsed,
        }),
    })

    print('Test Loss:{:.6f} Accuracy:{:.4f} \nTime:{:.1f}s'.format(test_loss, test_acc, elapsed))  

    return test_log, test_acc

いよいよ最後


def main():
    print(time_run)

    config = get_config()
    with open(save_dir+'/{}_config.json'.format(time_run), 'w') as fout:
        json.dump(config, fout, indent=2)
  #ここであらゆるパラメータをログとして残します。

    summary(model, (3, 32, 32))
    train_since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    writer = SummaryWriter(save_dir) #writerの初期化

    test(0, model, loss_func, test_loader, writer)

    epoch_logs = []
    train_accs = []
    test_accs = []

    for epoch in range(nepoch):
        print('Epoch {}/{}'.format(epoch, nepoch - 1))
        print('-' * 10)

        train_log, train_acc = train(epoch, model, loss_func, train_loader, optimizer, scheduler, writer)
        test_log, test_acc = test(epoch, model, loss_func, test_loader, writer)

        if test_acc > best_acc:
            best_acc = test_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        print()
        train_accs.append(train_acc)
        test_accs.append(test_acc)
        epoch_log = train_log.copy()
        epoch_log.update(test_log)
        epoch_logs.append(epoch_log)
        with open(save_dir+'/{}_log.json'.format(time_run), 'w') as fout:
            json.dump(epoch_logs, fout, indent=2)

        if epoch % 50 == 0 and epoch !=0: #50epochごとにモデルをセーブします。念のため。
            torch.save(model.state_dict(), save_dir+'/{}_ep{}_acc{:.4f}.pth'.format(time_run, epoch, test_acc))

    accuracies = {'train_accuracy':np.array(train_accs), 'test_accuracy':np.array(test_accs)}
    np.save(save_dir+'/{}_accs.npy'.format(time_run), accuracies)

    train_elapsed = time.time() - train_since
    print('Training complete in {:.0f}m {:.0f}s'.format(train_elapsed // 60, train_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load and save best model weights
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), save_dir+'/{}_bestacc{:.4f}.pth'.format(time_run, best_acc))

if __name__ == '__main__':
    main()

これでようやく完成しました。パラメータ数が2,932,266です。

Training complete in 1723m 44s
Best val Acc: 0.9498

いろいろ試行錯誤しながら、修正してきたモデルだが、
結果、95%正解率が出ました。論文では同じアーキテクトで96.45%出せましたが、初めての挑戦としてはこの結果でも納得できたと思います。

まとめ

  • 初心者はとりあえず書いてみましょう。
  • model.eval()を忘れないようにしましょう。
  • 複雑なモデルは部品ごとに分解してから書きましょう。
  • 関数や引数について不明なところがあれば、必ずdocumentationを見ましょう。
  • ログは貪欲に残しましょう。

最後に

楽しかったです。これからもいろいろ挑戦して、記録も残したいと思います。よろしくお願いします。

11
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
7