What if you want to rotate an image (whether it is 2D or 3D) within a Keras layer?
This is simple to do with NumPy because NumPy has rot90.
Here is a way to perform a 90˚ rotation using only Keras backend functions, so that it can be wrapped as a Keras layer and used in a CNN model.
Basic Strategy
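A 90˚ rotation can be done by combining a transpose and a flip. keras.backend.transpose transposes images, and keras.backend.reverse flips images along the specified axis. The axis along which the transposed image is flipped determines the direction of rotation.
Example: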
import numpy as np
from skimage.data import camera
from keras import backend as K
from matplotlib import pyplot as plt
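image = camera().astype(np.float32)  # 2D grayscale test image
image_tsp = K.transpose(image)  # swap rows and columns
image_cw90 = K.reverse(image_tsp, axes=1)  # clockwise: flip the transposed image left-right
image_ccw90 = K.reverse(image_tsp, axes=0)  # counterclockwise: flip the transposed image up-down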
plt.figure()
plt.subplot(221)
plt.imshow(image, cmap='gray')
plt.title('Original')
plt.axis('off')
plt.subplot(222)
plt.imshow(K.eval(image_tsp), cmap='gray')
plt.title('Transposed')
plt.axis('off')
plt.subplot(223)
plt.imshow(K.eval(image_cw90), cmap='gray')
plt.title('Rotated 90˚ cw')
plt.axis('off')
plt.subplot(224)
plt.imshow(K.eval(image_ccw90), cmap='gray')
plt.title('Rotated 90˚ ccw')
plt.axis('off')
plt.show()
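As a quick sanity check, the transpose-then-flip recipe can be compared against np.rot90 using NumPy alone. This is a minimal sketch rather than part of the Keras code: the helper names rot90_cw and rot90_ccw are hypothetical, and np.transpose and np.flip stand in for K.transpose and K.reverse.

```python
import numpy as np

def rot90_cw(img):
    """Rotate a 2D array 90 degrees clockwise: transpose, then flip left-right."""
    return np.flip(img.T, axis=1)

def rot90_ccw(img):
    """Rotate a 2D array 90 degrees counterclockwise: transpose, then flip up-down."""
    return np.flip(img.T, axis=0)

# Small non-square test image so that the shape change is visible
img = np.arange(6).reshape(2, 3)

# np.rot90 rotates counterclockwise for positive k
assert np.array_equal(rot90_ccw(img), np.rot90(img, k=1))
assert np.array_equal(rot90_cw(img), np.rot90(img, k=-1))
print(rot90_cw(img))
```

Using a non-square array makes it harder to miss a mistake, since a wrong axis order would also produce the wrong shape.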
Practical implementation
In a real CNN model, input images arrive as 4D or 5D tensors that include batch and channel dimensions. It is therefore better to use permute_dimensions rather than transpose, because the latter reverses all the dimensions whereas the former reorders only the dimensions you choose.
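The difference between reversing all axes and permuting selected ones is easiest to see from the resulting shapes. The sketch below uses NumPy's np.transpose as a stand-in for the Keras backend functions (an assumption made so the snippet runs without Keras), and the tensor shape is made up for illustration.

```python
import numpy as np

# Hypothetical 5D batch: (batch, depth, rows, columns, channel)
x = np.zeros((2, 4, 5, 6, 1))

# A plain transpose reverses every axis, moving batch and channel as well
all_reversed = np.transpose(x)
print(all_reversed.shape)  # (1, 6, 5, 4, 2)

# An explicit permutation swaps only rows and columns
in_plane = np.transpose(x, (0, 1, 3, 2, 4))
print(in_plane.shape)  # (2, 4, 6, 5, 1)
```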
In the following example, the dimensions are assumed to be as follows (they correspond to axes 0 to 4 in the code):
1: batch
2: depth
3: longitudinal (image rows)
4: horizontal (image columns)
5: channel
import numpy as np
from skimage.data import camera
from keras import backend as K
from matplotlib import pyplot as plt
# Create a multi dimensional image
image = camera().astype(np.float32)
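# Add channel and depth axes: (512, 512) -> (1, 512, 512, 1)
image = np.expand_dims(np.expand_dims(image, axis=2), axis=0)
# Stack along the depth axis twice: depth 1 -> 2 -> 4
image = np.concatenate((image, 0.5*image), 0)
image = np.concatenate((image, 0.5*image), 0)
# Add the batch axis and make a batch of two: (2, 4, 512, 512, 1)
image3D = np.expand_dims(image, axis=0)
image3D = np.concatenate((image3D, 0.5*image3D), 0)
# Swap only the two in-plane axes (2 and 3); batch, depth and channel stay in place
image_tsp = K.permute_dimensions(image3D, (0, 1, 3, 2, 4))
image_cw90 = K.reverse(image_tsp, axes=-2)  # clockwise: flip along the horizontal axis
image_ccw90 = K.reverse(image_tsp, axes=-3)  # counterclockwise: flip along the longitudinal axis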
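To check the 5D version without running Keras, the same permute-and-flip steps can be reproduced in NumPy and compared with np.rot90 applied to the two in-plane axes. This is only a sketch under that substitution: rotate_batch is a hypothetical helper, and the random tensor stands in for real image data.

```python
import numpy as np

def rotate_batch(x, clockwise=True):
    """Rotate each 2D slice of a (batch, depth, rows, cols, channel) array by 90 degrees."""
    swapped = np.transpose(x, (0, 1, 3, 2, 4))  # swap rows and columns only
    # Flipping axis -2 (columns) gives clockwise, axis -3 (rows) counterclockwise
    return np.flip(swapped, axis=-2 if clockwise else -3)

rng = np.random.default_rng(0)
x = rng.random((2, 4, 5, 6, 1))

cw = rotate_batch(x, clockwise=True)
ccw = rotate_batch(x, clockwise=False)

# np.rot90 with axes=(2, 3) rotates in the rows/columns plane; k=1 is counterclockwise
assert np.array_equal(ccw, np.rot90(x, k=1, axes=(2, 3)))
assert np.array_equal(cw, np.rot90(x, k=-1, axes=(2, 3)))
print(cw.shape)  # (2, 4, 6, 5, 1)
```

The batch, depth and channel axes keep their positions and sizes; only the row and column sizes are exchanged.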