6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Random Erasing Data Augmentation

Last updated at Posted at 2017-11-23

下記の論文を参考に、Random Erasing Data Augmentationを実装してみました。

Zhun Zhong, et al. (2017) Random Erasing Data Augmentation

こういう場で論文を引用するのは初めてなので、やり方がおかしかったり、著作権に引っかかっているなどあれば教えていただければ幸いです。

Data Augmentationの一種で、画像のクラス分類、物体検出、人物照合において有効性があるそうです。
(精度を測るところまでやりたかったのですが、手すきの記事なのでそこまでやってないです...すいません)

アルゴリズム

*Input*  
  Input image: I;  
  Image size: W * H;  
  Area of images: S;  
  Erasing probability: p;  
  Erasing area ratio range: (s_l, s_h);  
  Erasing aspect ratio range: (r1, r2);  

*Output*  
  Erased Image: I_erased;  
  
  
*Initialization*  
  p1 <- rand(0, 1);  
  
*Algorithm*  
  if p1 > p then  
    I_erased <- I;  
    return I_erased.  
  else 
    while True do  
      S_e <- rand(s_l, s_h) * S;  
      r_e <- rand(r1, r2);  
      
      H_e <- sqrt(S_e * r_e);  
      W_e <- sqrt(S_e / r_e);  
      
      x_e <- rand(0, W);  
      y_e <- rand(0, H);  
      
      if x_e + W_e <+ W and y_e + H_e <= H then
        I_e <- (x_e, y_e, x_e + W_e, y_e + H_e);  
        I(I_e) <- Rand(0, 255);  
        I_erased <- I;  
        return I_erased.  
      end
    end  
  end

コード

%matplotlib inline

# 無駄なimportがあるかもしれません
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from sklearn.utils import shuffle
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

from keras.datasets import cifar10

rng = np.random.RandomState(1234)
random_state = 42

(cifar_X_1, cifar_y_1), (cifar_X_2, cifar_y_2) = cifar10.load_data()  # 一つ目がtrain, 二つ目がtest

cifar_X = np.r_[cifar_X_1, cifar_X_2]
cifar_y = np.r_[cifar_y_1, cifar_y_2]

cifar_X = cifar_X.astype('float32') / 255
cifar_y = np.eye(10)[cifar_y.astype('int32').flatten()]

train_X, test_X, train_y, test_y = train_test_split(cifar_X, cifar_y, test_size=10000, random_state=random_state)
train_X, valid_X, train_y, valid_y = train_test_split(train_X, train_y, test_size=10000, random_state=random_state)


# 表示しやすいように関数にしています
def show_cifar10_images(dataset):
    fig = plt.figure(figsize=(9, 15))
    fig.subplots_adjust(left=0, right=1, bottom=0, top=0.5, hspace=0.05,
                        wspace=0.05)
    for i in range(81):
        ax = fig.add_subplot(9, 9, i + 1, xticks=[], yticks=[])
        ax.imshow(dataset[i])


# ここが今回のコード
# パラメタは論文に書いてある通りにしました。
def random_erasing(img, p = 0.5, s_l = 0.02, s_h = 0.4, r1 = 0.3, r2 = 1. / 0.3):
    p1 = np.random.uniform(0,1)
    if p1 < p:
        return img
    else:
        H = img.shape[0]
        W = img.shape[1]
        S = H * W
        while True:
            S_e = S * np.random.uniform(low=s_l, high=s_h)
            r_e = np.random.uniform(low=r1, high=r2)
            
            H_e = np.sqrt(S_e * r_e)
            W_e = np.sqrt(S_e / r_e)
            
            x_e = np.random.randint(0, W)
            y_e = np.random.randint(0, H)
            
            if x_e + W_e <= W and y_e + H_e <= H:
                img_erased = np.copy(img)
                img_erased[y_e:int(y_e + H_e + 1), x_e:int(x_e + W_e + 1), :] = np.random.uniform(0, 1)
                return img_erased


# 表示してみる
train_X_erased = np.copy(train_X)
for i in range(train_X_erased.shape[0]):
    train_X_erased[i] = random_erasing(train_X_erased[i])

show_cifar10_images(train_X_erased)

結果

random_erase_DA.png

うまく一部がeraseできています。

6
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?