Python
pandas
PyData
xarray

xarray を使った 中央値フィルタ (メディアンフィルタ)

More than 1 year has passed since last update.

xarray を使った 中央値フィルタ(メディアンフィルタ)

移動平均などのフィルタを適用する目的で xarray を使ったので記録を残します。

しなければいけなかったことは、多次元アレイのある1方向にのみ移動平均(や移動中央値フィルタ)をかけるものです。

ここでは例として画像データを扱います。
本来は、画像データは画像データ用のライブラリを用いるべきですが、適当なデータがなかったためご了承ください。

import numpy as np
import xarray as xr
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image
image = np.array(Image.open('lena.jpg'))
plt.imshow(image)
<matplotlib.image.AxesImage at 0x7f6b807a8898>

output_3_1.png

ここでは適当にサイズを落として、ソルトアンドペッパーノイズを加えます。

small_image = image[::3, ::3]
noisy_image = small_image.copy()
noisy_image[np.random.randint(0, noisy_image.shape[0], 1000), 
            np.random.randint(0, noisy_image.shape[1], 1000),
            np.random.randint(0, noisy_image.shape[2], 1000)] = 0
noisy_image[np.random.randint(0, noisy_image.shape[0], 1000), 
            np.random.randint(0, noisy_image.shape[1], 1000), 
            np.random.randint(0, noisy_image.shape[2], 1000)] = 256
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title('original')
plt.imshow(small_image)
plt.subplot(1, 2, 2)
plt.title('noisy')
plt.imshow(noisy_image)
<matplotlib.image.AxesImage at 0x7f6b7eefb6d8>

output_6_1.png

xr.DataArray に格納します。縦、横、色方向の軸をそれぞれ x, y, c とします。

data = xr.DataArray(noisy_image, dims=['x', 'y', 'c'])

x方向に移動平均を取るには、rollingメソッドを使います。
メソッドのキーワードには対応する軸名と、窓サイズを指定します。

rolling = data.rolling(x=3)  # x 方向に3ピクセルの窓を考える。
rolling
DataArrayRolling [window->3,center->False,dim->x]

rolling は、mean, median, max, min, count などに対応しています。

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title('mean')
plt.imshow(rolling.mean().astype('ubyte'))  # `imshow` で表示するために ubyte に変換しています。
plt.subplot(1, 2, 2)
plt.title('median')
plt.imshow(rolling.median().astype('ubyte'))
<matplotlib.image.AxesImage at 0x7f6b7edb7f60>

output_12_1.png

なお、よく見ると上の2ピクセルが黒くなっていると思います。これは、最初の2ピクセルに対しては移動平均ができないためです。そこには np.nan が入っています。

rolling.mean()
<xarray.DataArray (x: 171, y: 171, c: 3)>
array([[[        nan,         nan,         nan],
        [        nan,         nan,         nan],
        ..., 
        [        nan,         nan,         nan],
        [        nan,         nan,         nan]],

       [[        nan,         nan,         nan],
        [        nan,         nan,         nan],
        ..., 
        [        nan,         nan,         nan],
        [        nan,         nan,         nan]],

       ..., 
       [[  91.      ,   27.333333,   63.333333],
        [  93.666667,   26.666667,   61.      ],
        ..., 
        [ 125.666667,   25.333333,   70.      ],
        [ 143.      ,   48.333333,   69.333333]],

       [[  87.333333,   26.333333,   62.666667],
        [  91.333333,   25.333333,   58.666667],
        ..., 
        [ 149.      ,   39.333333,   79.666667],
        [ 162.333333,   60.333333,   73.666667]]])
Dimensions without coordinates: x, y, c

それがいやな場合は、rolling(x=3, min_periods=1) のように、最小のウィンドウ幅を指定します。min_periods=1 では、端の点は1ピクセルの移動平均の値(実質的には値そのまま)が入れられます。

plt.imshow(data.rolling(x=3, min_periods=1).median().astype('ubyte'))
<matplotlib.image.AxesImage at 0x7f6b7ed676a0>

output_16_1.png

よく見ると、この絵は平均的に下方向にシフトします。上の部分を拡大してみると

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(data.rolling(x=4, min_periods=1).median().astype('ubyte')[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
plt.subplot(1, 2, 2)
plt.imshow(small_image[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
[<matplotlib.lines.Line2D at 0x7f6b7ec430f0>]

output_18_1.png

このようなシフトを望まない場合は、center=True を指定します。

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(data.rolling(x=4, min_periods=1, center=True).median().astype('ubyte')[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
plt.subplot(1, 2, 2)
plt.imshow(small_image[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
[<matplotlib.lines.Line2D at 0x7f6b7ebdc7b8>]

output_20_1.png