32
37

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

特徴空間のマハラノビス距離を利用した継続学習を実装した

Posted at

継続学習

継続学習とはモデルが長い期間次々と与えられる新しいデータに対して連続的に学習すること
詳しくは継続学習についてまとめたスライド参照

今回は分布外検知の手法を利用した継続学習の論文を実装する
論文タイトル: A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks
論文の概要: 分布外データを新しいクラスのデータとみなし,深層モデルの特徴空間上でガウシアンフィッティングを行い,古いクラスと新しいクラスの平均ベクトルとのマハラノビス距離をもとにテストデータを分類する.
論文の詳しい解説はスライドの2つ目にかいた

この論文の強みはこれまで学習したクラスの精度をほぼ落とさずに継続学習が可能であり,新たに追加されたクラスのデータを利用したDNNのパラメータの再学習が必要がないことである.
そのため,クラス数が追加で増えても非常に高速に学習ができる.

実装の問題設定

  • CIFAR10の0~4番目のクラスが最初に訓練データとして与えられ,追加で5~9番目のクラスの訓練データが次々と与えられる
  • 追加で与えられたクラスのデータのみを使って学習し,最初に与えられたクラスのデータを使って再学習することはできない

実装解説

まずは準備

import os 

import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import transforms as T
from torchvision.datasets import CIFAR10

# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# 上記レポジトリのmodelとutilsを同じ階層におくこと
from model import EfficientNet 
from tqdm import tqdm
plt.style.use("ggplot")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CIFAR10のクラスを追加できるようにdataloaderを返す関数を用意した

def return_data_loader(classes, train=True, batch_size=128):
    transform = []
    transform.append(T.Resize((64, 64))) # efmodelを使うためにresizeする必要がある
    transform.append(T.ToTensor())
    transform = T.Compose(transform)

    dataset = CIFAR10("./data", train=train, download=True, transform=transform)
    targets = np.array(dataset.targets)
    mask = np.array([t in classes for t in targets])
    dataset.data = dataset.data[mask]
    dataset.targets = targets[mask]
    
    data_loader = DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=train)
    return data_loader

モデルとしてefficientnetの一番小さいモデルを利用する

まず,図の青の矢印部分として,普通に5クラスの識別モデルを学習する
次に,赤の矢印部分として一層前の特徴をクラスごとにガウス分布で近似する
スクリーンショット 2020-10-05 9.08.56.png

NCLASS = 5 # 初期クラス
classes = np.arange(NCLASS)
model = 'efficientnet-b0'
weight_dir = "."

clf = EfficientNet.from_name(model)
clf._fc = torch.nn.Linear(clf._fc.in_features, NCLASS)
clf = clf.to(device)
clf.train()
train_loader = return_data_loader(classes=classes, train=True)

lr = 0.001
epoch_num = 50
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(clf.parameters(), lr=lr)

for epoch in tqdm(range(epoch_num)):
    train_loss = 0
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        logit = clf(x)
        loss = criterion(logit, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= len(train_loader.dataset)
torch.save(clf.state_dict(), os.path.join(weight_dir, 'weight.pth'))

test_loader = return_data_loader(range(10), train=False)
clf.load_state_dict(torch.load(os.path.join(weight_dir, 'weight.pth')))
clf.eval()

pred = []
true = []
for x, y in test_loader:
    with torch.no_grad():
            pred.extend(clf(x.to(device)).max(1)[1].detach().cpu().numpy())
            true.extend(y.numpy())
print(accuracy_score(true, pred))
print(confusion_matrix(true, pred))

まずは,普通に識別モデルを学習したときの混合行列と正解率を出力した
0~4番目のクラスしか学習につかっていないため,当然5~9番目のクラスの予測はできず,無理やり0~4番目のクラスを予測するようになっている.

0.4279 # 正解率 
[[877  27  47  35  14   0   0   0   0   0]
 [ 14 972   3   8   3   0   0   0   0   0]
 [ 51   7 785  81  76   0   0   0   0   0]
 [ 20  18 107 780  75   0   0   0   0   0]
 [ 13   2  58  62 865   0   0   0   0   0]
 [ 13  12 226 640 109   0   0   0   0   0]
 [ 26  55 232 477 210   0   0   0   0   0]
 [ 47  21 188 230 514   0   0   0   0   0]
 [604 214  53  95  34   0   0   0   0   0]
 [160 705  43  78  14   0   0   0   0   0]]

次に上の図の赤の矢印部分の実装のために,特徴量の平均と共分散を計算する
スクリーンショット 2020-10-05 9.29.18.png

def ext_feature(x):
    z = clf.extract_features(x)
    z = clf._avg_pooling(z)
    z = z.flatten(start_dim=1)
    return z.detach().cpu().numpy()

train_loaders = [return_data_loader(classes=[c], train=True) for c in range(10)]

z_mean = []
z_var = 0
target_count = []

for c in tqdm(range(NCLASS)): # 既存のクラス
    N = len(train_loaders[c].dataset) # 各クラス数の保持
    target_count.append(N)
    
    with torch.no_grad():
        # 平均の計算
        new_z_mean = 0
        for x, _ in train_loaders[c]:
            x = x.to(device)
            new_z_mean += ext_feature(x).sum(0) / N
        z_mean.append(new_z_mean)

        # 分散の計算
        for x, _ in train_loaders[c]:
            x = x.to(device)    
            z_var += (ext_feature(x) - new_z_mean).T.dot(ext_feature(x) - new_z_mean) / N

C = len(z_mean)
z_var /=  C
z_mean = np.array(z_mean)
target_count = np.array(target_count)

平均と共分散がもとまったらマハラノビス距離を利用して最終層の全結合層なしに分類することが可能である
実装ではベイズの定理を利用している
スクリーンショット 2020-10-05 9.32.07.png
ここで,$\beta_c$はクラスのデータの個数である

z_var_inv = np.linalg.inv(z_var + np.eye(z_mean.shape[1])*1e-6)  
# 逆行列が不安定になるのを防ぐために正則化をいれる
A = z_mean.dot(z_var_inv) # 分子のexpの中身の1項目
B = (A*z_mean).sum(1) * 0.5 # 2項目
beta = np.log(target_count) # 3項目

accs = []
pred = []
true = []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        pred.extend((A.dot(ext_feature(x).T) - B[:, None] + beta[:, None]).argmax(0))
        true.extend(y.numpy())
acc = accuracy_score(true, pred)        
print(acc)
accs.append(acc)
confusion_matrix(true, pred)

以下の結果より,全結合層使わなくても,ほとんど同じくらいの正解率を達成できることがわかった

0.4273 # 正解率
array([[899,  17,  43,  29,  12,   0,   0,   0,   0,   0],
       [ 25, 958,   6,   9,   2,   0,   0,   0,   0,   0],
       [ 55,   6, 785,  86,  68,   0,   0,   0,   0,   0],
       [ 29,  15, 109, 773,  74,   0,   0,   0,   0,   0],
       [ 23,   2,  55,  62, 858,   0,   0,   0,   0,   0],
       [ 22,   6, 227, 641, 104,   0,   0,   0,   0,   0],
       [ 34,  39, 256, 468, 203,   0,   0,   0,   0,   0],
       [ 71,  16, 199, 214, 500,   0,   0,   0,   0,   0],
       [653, 182,  53,  84,  28,   0,   0,   0,   0,   0],
       [221, 645,  42,  78,  14,   0,   0,   0,   0,   0]])

継続学習の実装

モデルのパラメータの訓練をすることなく,新しいデータの平均と分散をもとに,全てのクラスのテストデータを分類することが目標となる

アルゴリズムの概要は以下のようになる
スクリーンショット 2020-10-05 9.39.07.png

for c in tqdm(range(NCLASS, 10)): # 新規クラス
    N = len(train_loaders[c].dataset)

    with torch.no_grad():
        # 平均の計算
        new_z_mean = 0        
        for x, _ in train_loaders[c]:
            x = x.to(device)
            new_z_mean += ext_feature(x).sum(0) / N 


        # 分散の計算
        new_z_var = 0
        for x, _ in train_loaders[c]:
            x = x.to(device)    
            new_z_var += (ext_feature(x) - new_z_mean).T.dot(ext_feature(x) - new_z_mean) / N

    #平均と分散の更新
    C = len(target_count)
    z_mean = np.concatenate([z_mean, new_z_mean[None, :]])
    z_var = z_var*C/(C+1) + new_z_var/(C+1)
    target_count = np.append(target_count, N)

    z_var_inv = np.linalg.inv(z_var + np.eye(z_mean.shape[1])*1e-6)
    A = z_mean.dot(z_var_inv) 
    B = (A*z_mean).sum(1) * 0.5
    beta = np.log(target_count)
    pred = []
    true = []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            pred.extend((A.dot(ext_feature(x).T) - B[:, None] + beta[:, None]).argmax(0))
            true.extend(y.numpy())
    acc = accuracy_score(true, pred)
    accs.append(acc)
    print(acc)

最終的な結果は以下である

0.4974 # 正解率
array([[635,   1,  18,   4,   2,  14,   9,  36, 260,  21],
       [  1, 761,   0,   1,   0,   0,   8,   3,  21, 205],
       [ 20,   0, 581,  12,   8,  97, 105, 135,  35,   7],
       [  5,   0,  22, 450,  13, 256, 147,  60,  29,  18],
       [  2,   1,  16,  10, 555,  30,  63, 302,  20,   1],
       [  1,   0,  57, 288,  22, 325, 173, 106,  22,   6],
       [  5,   0,  49, 139,  36, 182, 350, 161,  35,  43],
       [  5,   2,  34,  50, 131, 104, 158, 446,  58,  12],
       [226,  26,  13,  11,   3,  22,  58,  41, 430, 170],
       [ 17, 250,   6,   5,   0,   8,  69,  16, 188, 441]])
plt.title("accuracy")
plt.plot(accs)
plt.show()

x軸は加えたクラス数を意味する.
最終的に10クラスの訓練データが与えられたときの正解率は,5クラスだけ与えられたときよりも,0.1程度あがっていることがわかる.
download-6.png

32
37
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
32
37

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?