LoginSignup
2
0

More than 5 years have passed since last update.

( skimage.transform による ) Distorted MNIST 作成に関するメモ

Last updated at Posted at 2018-06-08

ポイント

  • skimage.transform を用いて、distorted MNIST を作成。
  • 今後、distorted MNIST を使用し、モデルのパフォーマンス検証を実施。

レファレンス

1. scikit-image

データ

MNIST handwritten digits

###
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST through the TensorFlow tutorial helper; labels are requested
# one-hot encoded.
mnist = input_data.read_data_sets('***/mnist', one_hot=True)

### 
from sklearn import datasets

# `fetch_mldata` has been removed from scikit-learn (and mldata.org is
# offline), so the data set is fetched from OpenML instead.
mnist = datasets.fetch_openml('mnist_784', version=1, as_frame=False)
images = mnist.data / 255.0  # rescale pixel intensities by 1/255
# OpenML delivers the labels as strings; cast them so they stay numeric.
targets = mnist.target.astype(np.float64)

# Shuffle all indices once, then hold out the first fifth of the shuffled
# order as the test split and keep the rest for training.
n = len(images)
s = np.random.permutation(n)
nn = int(n / 5)
images_train, images_test = images[s[nn:]], images[s[:nn]]
targets_train, targets_test = targets[s[nn:]], targets[s[:nn]]

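サンプルコード1

※ 以下のコードは、事前に `import numpy as np` と `import matplotlib.pyplot as plt` を実行し、`mnist` を上記の TensorFlow 版（`input_data.read_data_sets`）で読み込んでいることを前提とする。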

import skimage.transform

# initial: draw a random index below 1000 and view that training sample as
# a 28x28 image
index = np.random.randint(1000)
image_initial = np.reshape(mnist.train.images[index], (28, 28))

# rotation: turn the image by `angle` degrees about the point (14, 14).
# `warp` uses the transform as the map from output to input coordinates.
angle = 90
theta = float(angle) / 180 * np.pi
cos_t, sin_t = np.cos(theta), np.sin(theta)

# Variant 1: explicit homogeneous matrix. The third column is the offset
# that leaves (14, 14) fixed under the rotation.
matrix_r = np.array([
    [cos_t, -sin_t, 14 * (1 + sin_t - cos_t)],
    [sin_t, cos_t, 14 * (1 - sin_t - cos_t)],
    [0, 0, 1],
])
tform_r_1 = skimage.transform.AffineTransform(matrix=matrix_r)
image_r_1 = skimage.transform.warp(image_initial, tform_r_1)

# Variant 2: estimate the same map from four corner correspondences
# (each point of `dst_r` is sent to the matching point of `src_r`).
dst_r = np.array([[7, 7], [21, 7], [21, 21], [7, 21]])
src_r = np.array([[21, 7], [21, 21], [7, 21], [7, 7]])
tform_r_2 = skimage.transform.ProjectiveTransform()
tform_r_2.estimate(dst_r, src_r)
image_r_2 = skimage.transform.warp(image_initial, tform_r_2)

# scale: magnify by `scale` about the point (14, 14). The matrix carries
# the reciprocal factor because `warp` maps output coordinates to input
# coordinates.
scale = 2
s = 1.0 / scale
offset_s = -14 * (s - 1)  # keeps (14, 14) fixed under the scaling

# Variant 1: explicit homogeneous matrix.
matrix_s = np.array([
    [s, 0, offset_s],
    [0, s, offset_s],
    [0, 0, 1],
])
tform_s_1 = skimage.transform.AffineTransform(matrix=matrix_s)
image_s_1 = skimage.transform.warp(image_initial, tform_s_1)

# Variant 2: estimate the map from corner correspondences; the 14-pixel
# square of `dst_s` is sent to the 7-pixel square of `src_s`.
dst_s = np.array([[7, 7], [21, 7], [21, 21], [7, 21]])
src_s = np.array([
    [10.5, 10.5],
    [17.5, 10.5],
    [17.5, 17.5],
    [10.5, 17.5],
])
tform_s_2 = skimage.transform.ProjectiveTransform()
tform_s_2.estimate(dst_s, src_s)
image_s_2 = skimage.transform.warp(image_initial, tform_s_2)

# translation: the matrix adds (t_x, t_y) to every output coordinate to
# find the input pixel that `warp` samples.
t_x = 5
t_y = 2

# Variant 1: explicit homogeneous matrix.
matrix_t = np.array([
    [1, 0, t_x],
    [0, 1, t_y],
    [0, 0, 1],
])
tform_t_1 = skimage.transform.AffineTransform(matrix=matrix_t)
image_t_1 = skimage.transform.warp(image_initial, tform_t_1)

# Variant 2: estimate the map from corner correspondences; `src_t` is the
# reference square `dst_t` displaced by (t_x, t_y).
dst_t = np.array([[7, 7], [21, 7], [21, 21], [7, 21]])
src_t = dst_t + np.array([t_x, t_y])
tform_t_2 = skimage.transform.ProjectiveTransform()
tform_t_2.estimate(dst_t, src_t)
image_t_2 = skimage.transform.warp(image_initial, tform_t_2)

# flip: mirror left-right; the matrix sends x to 28 - x and leaves y alone.

# Variant 1: explicit homogeneous matrix.
matrix_f = np.array([
    [-1, 0, 28],
    [0, 1, 0],
    [0, 0, 1],
])
tform_f_1 = skimage.transform.AffineTransform(matrix=matrix_f)
image_f_1 = skimage.transform.warp(image_initial, tform_f_1)

# Variant 2: estimate the map from corner correspondences; `src_f` lists
# the corners of `dst_f` with their x coordinates mirrored.
dst_f = np.array([[7, 7], [21, 7], [21, 21], [7, 21]])
src_f = np.array([[21, 7], [7, 7], [7, 21], [21, 21]])
tform_f_2 = skimage.transform.ProjectiveTransform()
tform_f_2.estimate(dst_f, src_f)
image_f_2 = skimage.transform.warp(image_initial, tform_f_2)


fig = plt.figure(figsize=(8, 5))

# (slot in the 3x4 grid, image, title). Slots 2-4 are skipped, so the
# untouched digit sits alone on the first row.
panels = [
    (1, image_initial, 'Initial'),
    (5, image_r_1, 'Rotation 1'),
    (6, image_r_2, 'Rotation 2'),
    (7, image_s_1, 'Scale 1'),
    (8, image_s_2, 'Scale 2'),
    (9, image_t_1, 'Translation 1'),
    (10, image_t_2, 'Translation 2'),
    (11, image_f_1, 'Flip 1'),
    (12, image_f_2, 'Flip 2'),
]

for slot, picture, title in panels:
    axes = fig.add_subplot(3, 4, slot)
    axes.imshow(picture, cmap='gray')
    axes.set_title(title)
    axes.set_axis_off()

plt.show()

アウトプット1

image.png

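サンプルコード2

※ 以下のコードも、事前に `import numpy as np` と `import matplotlib.pyplot as plt` を実行し、`mnist` を TensorFlow 版（`input_data.read_data_sets`）で読み込んでいることを前提とする。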

import skimage.transform

def distorted_image_generator(original_images):
    """Return randomly rotated, scaled and shifted copies of 28x28 images.

    Each input image is viewed as 28x28 and passed through three successive
    ``skimage.transform.warp`` calls:

    1. a rotation about (14, 14) by an angle drawn from [0, 360) degrees;
    2. a scaling about (14, 14) whose matrix factor is 1.0 with probability
       0.5 and otherwise drawn from [1.0, 1.5);
    3. a translation whose matrix offsets are integers drawn from 0..4,
       independently per axis.

    ``warp`` uses each matrix as the map from output to input coordinates,
    so a scale factor above 1 makes the digit smaller.

    Args:
        original_images: Sequence of images, each reshapeable to 28x28.

    Returns:
        A list holding one flattened array of length 784 per input image,
        in input order; an empty list when ``original_images`` is empty.
    """
    distorted_images = []

    for original in original_images:
        image = np.reshape(original, [28, 28])

        s = 1.0
        if np.random.uniform(low=0.0, high=1.0) < 0.5:
            s = np.random.uniform(low=1.0, high=1.5)
        # Name both bounds: a lone positional argument is taken as ``low``
        # (with ``high`` left at its default of 1.0), not as the upper bound.
        angle = np.random.uniform(low=0.0, high=360.0)
        tx = np.random.randint(5)
        ty = np.random.randint(5)

        # rotate; the third column keeps (14, 14) fixed
        theta = 1.0 * angle / 180 * np.pi
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        matrix_r = np.array([
            [cos_t, -sin_t, 14 * (1 + sin_t - cos_t)],
            [sin_t, cos_t, 14 * (1 - sin_t - cos_t)],
            [0, 0, 1],
        ])
        tform_r = skimage.transform.AffineTransform(matrix=matrix_r)
        image_distorted = skimage.transform.warp(image, tform_r)

        # scale; the third column keeps (14, 14) fixed
        matrix_s = np.array([
            [s, 0, -14 * (s - 1)],
            [0, s, -14 * (s - 1)],
            [0, 0, 1],
        ])
        tform_s = skimage.transform.AffineTransform(matrix=matrix_s)
        image_distorted = skimage.transform.warp(image_distorted, tform_s)

        # translation
        matrix_t = np.array([
            [1, 0, tx],
            [0, 1, ty],
            [0, 0, 1],
        ])
        tform_t = skimage.transform.AffineTransform(matrix=matrix_t)
        image_distorted = skimage.transform.warp(image_distorted, tform_t)

        distorted_images.append(np.reshape(image_distorted, [28 * 28]))

    return distorted_images

# Distort the first 1000 training images and the first 1000 test images,
# printing the shape of each distorted set.
original_images_train = mnist.train.images[:1000]
distorted_images_train = distorted_image_generator(original_images_train)
print(np.shape(distorted_images_train))

original_images_test = mnist.test.images[:1000]
distorted_images_test = distorted_image_generator(original_images_test)
print(np.shape(distorted_images_test))

# Show one randomly chosen test digit next to its distorted counterpart.
index = np.random.randint(len(distorted_images_test))

fig = plt.figure(figsize=(4, 6))
ax1, ax2 = (fig.add_subplot(1, 2, column) for column in (1, 2))

views = (
    (ax1, original_images_test, 'Original'),
    (ax2, distorted_images_test, 'Distorted'),
)
for axes, batch, title in views:
    axes.imshow(np.reshape(batch[index], [28, 28]), cmap='gray')
    axes.set_title(title)
    axes.set_axis_off()

plt.show()

アウトプット2

image.png


image.png

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0