2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

k-WTAは敵対的摂動に対して頑健である。その実験

Last updated at Posted at 2021-12-14

概要

Enhancing Adversarial Defense by k-Winners-Take-All

k-WTAは入力ベクトルの中で、上位k個の大きい数値のみをそのまま残し、それ以外を0にするという活性化関数。k-WTAは入力に対して不連続であるため、ネットワークを騙す摂動が計算しにくい。また、パラメータ空間においては、不連続性が小さいため、ネットワークの学習は従来通り可能。

背景

DNNは敵対的摂動と呼ばれる小さなノイズによって簡単に騙されることがわかっています。

adv_sample.png

DNNの社会応用のためには、このような攻撃からDNNを防御する必要があります。

提案手法

提案手法は非常にシンプルで、k-WTAという活性化関数をDNNに組み込むことです。

10235F1D-DC67-48AB-8048-E02443A61550.jpeg

k-WTAは入力ベクトルの中で、上位k個の大きい数値のみをそのまま残し、それ以外を0にするという活性化関数です。どれだけの数を残すかを表すパラメータkは、ハイパーパラメータであり、ネットワークを構築する際にあらかじめ決めておく値です。

特性

k-WTAには以下のような特性があります。

1. 入力に対して不連続

Adversarial AttackはDNNの入力に対する勾配を必要とする。例えば、FGSMは以下のような式で求めます。

$$\rho(x)=\epsilon\times sign(\nabla_xL(x,y))$$

k-WTAでは入力に対して、不連続な関数となっています。

下図を見ると、ある層からの入力$x$がk-WTAに入力され、その後重みWによって変換されています。この時、k-WTAによって、重みWのi列目の成分が寄与するのかの一貫性を失います。

摂動は入力に対する勾配情報をもとに、繰り返し更新されるため、一貫性を失うことで摂動の計算を困難になります。

そのため、k-WTAはAdversarial Attackに有効となります

08E94E33-816E-4882-B072-C75713895B85.jpeg

2. パラメータの更新には影響しない

上述したように、k-WTAを用いたネットワークでは入力に対して不連続であるため、入力の勾配の計算は困難です。ではパラメータの更新、つまりネットワークの学習はどうなるのか。

答えは、「従来通りに学習できる」です。理由として以下の点が挙げられます。

  • パラメータ空間は入力空間より非常に大きいため、損失は入力に対しては不連続であるがパラメータ空間では不連続の部分が小さくなるため

  • 任意のデータ$x_i$と$x_j$に対して、ある条件下で以下の定理が成り立つため

    $$A(Wx_i+b)\cap A(Wx_j+b)=\varnothing$$

    $A$は活性化のパターンを示したもの。つまり、各データによるパラメータの更新は独立に行われます。

3. 既存手法と同じ計算量

k-WTAはベクトルが与えられた時に、その中でk番目に大きい数値を見つけるという計算を行います。この計算時間はベクトルの長さに比例します。同様にReLUもベクトルの長さに比例する計算を行います。つまり、k-WTAを用いても、計算量の変化はありません。

実験1

CIFAR10を用いて、k-WTAあり/なしのDNNに対して敵対的攻撃(FGSM)をして検証します。

実装方法

k-WTAの自作レイヤーは以下の通りです。

class KWTA(tf.keras.layers.Layer):
  def __init__(self, k, **kwargs) :
    super().__init__(**kwargs)
    self.k = k

  def call(self, x):
    topk = tf.math.top_k(x[0], k=self.k)   
    topk_min = K.min(topk.values)        
    comp = tf.dtypes.cast(x >= topk_min, tf.float32)   
    return tf.math.multiply(x,comp)       

ソースコード

Google Colaboratory

結果

以下の表の数値はAccuracyの値

Input imgs ReLU k-WTA
clean test 0.8335 0.825
attacked test (ε=0.10) 0.1024 0.1309

ReLUモデルよりはロバスト?

実験2

実験1の結果が少し微妙だったので、モデルを深くして再度検証。

ソースコード

Google Colaboratory

結果

以下の表の数値はAccuracyの値

Input imgs ReLU k-WTA
clean test 0.8071 0.8026
attacked test (ε=0.10) 0.1025 0.1237

上の実験の結果との一貫性あり。

実験3

今までの実験では1Dのk-WTAのみを実装していたため、Conv2DについてはReLUのままでした。

そこで、Conv2Dに対応した2Dのk-WTAを実装し、全てのReLUをk-WTAに置き換えて検証してみます。

モデルの深さは実験1と同じ。

実装方法

class KWTA2D(tf.keras.layers.Layer):
    def __init__(self, k=None, **kwargs):
        super().__init__(**kwargs)
        self.k = k
        
    def call(self, x):
        n = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        if self.k == None:
            k = n * 7 // 10
        else:
            k = self.k
        x_flatten = tf.reshape(x[0],  [n])
        topk = tf.math.top_k(x_flatten, k=k)
        topk_min = K.min(topk.values) 
        comp = tf.dtypes.cast(x >= topk_min, tf.float32)  
        return tf.math.multiply(x,comp)

チャネルを含めて全て平坦化させて、上位k個を選択しています。

kについては全体の7/10にすることにしました。

また、Adamのパラメータを少し設定しました。今まで通りに設定せずに学習させると、25epochあたりからaccuracyが急激に下がった(0.9 → 0.006)ので...。おそらく局所最適解か何かに入ったから(でもそんなに下がる?)。

ソースコード

Google Colaboratory

結果

以下の表の数値はAccuracyの値

Input imgs ReLU k-WTA
clean test 0.8071 0.7924
attacked test (ε=0.10) 0.1025 0.1146

学習していく過程を見た感じ、今の設定(ハイパーパラメータ含めて)のままではaccuracy 0.80を超えることはなさそう。

防御性能は...微妙。でもReLUモデルよりはいい。

実験4

実験1~3を踏まえて、ReLUとk-WTAを併用したらどうなるのかを検証。

具体的には、x → k-WTA → ReLU → output

ソースコード

Google Colaboratory

結果

以下の表の数値はAccuracyの値

Input imgs ReLU ReLU + k-WTA
clean test 0.8071 0.8258
attacked test (ε=0.10) 0.1025 0.1048

ReLU+k-WTAモデルが一番下がった結果になりました。

また、ReLUを加えるだけでcleanに対するaccuracyが上がりました。

実験5

kを上位1割になるよう設定しました。

実装方法

class KWTA(tf.keras.layers.Layer):
  def __init__(self, k=None, **kwargs) :
    super().__init__(**kwargs)
    self.k = k

  def call(self, x):
    n = x[0].shape[0]
    if self.k == None:
            k = n * 1 // 10
    else:
            k = self.k
    topk = tf.math.top_k(x[0], k=k)   
    topk_min = K.min(topk.values)         
    comp = tf.dtypes.cast(x >= topk_min, tf.float32)  
    return tf.math.multiply(x,comp) 

class KWTA2D(tf.keras.layers.Layer):
    def __init__(self, k=None, **kwargs):
        super().__init__(**kwargs)
        self.k = k
        
    def call(self, x):
        n = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        if self.k == None:
            k = n * 1 // 10
        else:
            k = self.k
        x_flatten = tf.reshape(x[0],  [n])
        topk = tf.math.top_k(x_flatten, k=k)
        topk_min = K.min(topk.values) 
        comp = tf.dtypes.cast(x >= topk_min, tf.float32)
        return tf.math.multiply(x,comp)

ソースコード

Google Colaboratory

結果

以下の表の数値はAccuracyの値

Input imgs ReLU k-WTA
clean test 0.8071 0.788
attacked test (ε=0.10) 0.1025 0.1319

特に改善はされませんでした。

実験6

実験1~5はCIFAR10を対象に行いました。

ここで、簡単なMNISTを対象に行い、明確な結果を期待して検証してみます。

ソースコード

Google Colaboratory

結果

以下の表の数値はAccuracyの値

Input imgs ReLU k-WTA
clean test 0.9905 0.991
attacked test (ε=0.10) 0.8358 0.9581
attacked test (ε=0.15) 0.5209 0.9145
attacked test (ε=0.20) 0.2452 0.8426

すごい。圧倒的にk-WTAモデルの方がロバストでした。

まとめ

一貫して、ReLUモデルよりk-WTAモデルの方がロバストでした。

MNISTにおいては効果的だったが、CIFAR10では微妙でした。これについては、研究室の先生や先輩方から「論文ではDeepFool等を使っていて、この検証はFGSMを使っているから、論文ほどの差が見られなかったのでは」というご指摘を受けました。時間のある時に検証しておきたいです。

全てのソースコードはこちら

参考

k-WTAについて

Adversarial Attackへの防御の鍵は活性化関数?新たな活性化関数k-WTAの登場!

k-WTAの自作レイヤーのために

GitHub - a554b554/kWTA-Activation

新人工知能プロジェクト

テンソルについて

TensorFlowのTensorオブジェクトに慣れたい - Qiita

テンソルの基礎 | TensorFlow Core

tensorflowのバックエンドについて

tf.math.top_k | TensorFlow Core v2.6.0

tf.cast | TensorFlow Core v2.6.0

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?