xarray を使った 中央値フィルタ(メディアンフィルタ)
移動平均などのフィルタを適用する目的で xarray を使ったので記録を残します。
しなければいけなかったことは、多次元アレイのある1方向にのみ移動平均(や移動中央値フィルタ)をかけるものです。
ここでは例として画像データを扱います。
本来は、画像データは画像データ用のライブラリを用いるべきですが、適当なデータがなかったためご了承ください。
import numpy as np
import xarray as xr
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image
image = np.array(Image.open('lena.jpg'))
plt.imshow(image)
<matplotlib.image.AxesImage at 0x7f6b807a8898>
ここでは適当にサイズを落として、ソルトアンドペッパーノイズを加えます。
small_image = image[::3, ::3]
noisy_image = small_image.copy()
noisy_image[np.random.randint(0, noisy_image.shape[0], 1000),
np.random.randint(0, noisy_image.shape[1], 1000),
np.random.randint(0, noisy_image.shape[2], 1000)] = 0
noisy_image[np.random.randint(0, noisy_image.shape[0], 1000),
np.random.randint(0, noisy_image.shape[1], 1000),
np.random.randint(0, noisy_image.shape[2], 1000)] = 256
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title('original')
plt.imshow(small_image)
plt.subplot(1, 2, 2)
plt.title('noisy')
plt.imshow(noisy_image)
<matplotlib.image.AxesImage at 0x7f6b7eefb6d8>
xr.DataArray
に格納します。縦、横、色方向の軸をそれぞれ x
, y
, c
とします。
data = xr.DataArray(noisy_image, dims=['x', 'y', 'c'])
x
方向に移動平均を取るには、rolling
メソッドを使います。
メソッドのキーワードには対応する軸名と、窓サイズを指定します。
rolling = data.rolling(x=3) # x 方向に3ピクセルの窓を考える。
rolling
DataArrayRolling [window->3,center->False,dim->x]
rolling は、mean
, median
, max
, min
, count
などに対応しています。
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title('mean')
plt.imshow(rolling.mean().astype('ubyte')) # `imshow` で表示するために ubyte に変換しています。
plt.subplot(1, 2, 2)
plt.title('median')
plt.imshow(rolling.median().astype('ubyte'))
<matplotlib.image.AxesImage at 0x7f6b7edb7f60>
なお、よく見ると上の2ピクセルが黒くなっていると思います。これは、最初の2ピクセルに対しては移動平均ができないためです。そこには np.nan が入っています。
rolling.mean()
<xarray.DataArray (x: 171, y: 171, c: 3)>
array([[[ nan, nan, nan],
[ nan, nan, nan],
...,
[ nan, nan, nan],
[ nan, nan, nan]],
[[ nan, nan, nan],
[ nan, nan, nan],
...,
[ nan, nan, nan],
[ nan, nan, nan]],
...,
[[ 91. , 27.333333, 63.333333],
[ 93.666667, 26.666667, 61. ],
...,
[ 125.666667, 25.333333, 70. ],
[ 143. , 48.333333, 69.333333]],
[[ 87.333333, 26.333333, 62.666667],
[ 91.333333, 25.333333, 58.666667],
...,
[ 149. , 39.333333, 79.666667],
[ 162.333333, 60.333333, 73.666667]]])
Dimensions without coordinates: x, y, c
それがいやな場合は、rolling(x=3, min_periods=1)
のように、最小のウィンドウ幅を指定します。min_periods=1
では、端の点は1ピクセルの移動平均の値(実質的には値そのまま)が入れられます。
plt.imshow(data.rolling(x=3, min_periods=1).median().astype('ubyte'))
<matplotlib.image.AxesImage at 0x7f6b7ed676a0>
よく見ると、この絵は平均的に下方向にシフトします。上の部分を拡大してみると
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(data.rolling(x=4, min_periods=1).median().astype('ubyte')[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
plt.subplot(1, 2, 2)
plt.imshow(small_image[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
[<matplotlib.lines.Line2D at 0x7f6b7ec430f0>]
このようなシフトを望まない場合は、center=True
を指定します。
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(data.rolling(x=4, min_periods=1, center=True).median().astype('ubyte')[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
plt.subplot(1, 2, 2)
plt.imshow(small_image[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
[<matplotlib.lines.Line2D at 0x7f6b7ebdc7b8>]