Posted at

NumPy で行列の差分計算

極めてニッチで一体誰の参考になるのかという小ネタだけど、自分用のメモとして。


前提

np で NumPy を使えるものとする。

import numpy as np


やりたいこと

以下のような要素のリスト (例: 1から6までの整数) があったとする。

arr = np.arange(1, 7)  # array([1, 2, 3, 4, 5, 6])

このリストから2つの要素を選ぶ全ての組み合わせについて、何らかの計算 (例: 2要素の積) をした行列を求めたとする。

X, Y = np.meshgrid(arr, arr)

matrix = X * Y
# array([[ 1, 2, 3, 4, 5, 6],
# [ 2, 4, 6, 8, 10, 12],
# [ 3, 6, 9, 12, 15, 18],
# [ 4, 8, 12, 16, 20, 24],
# [ 5, 10, 15, 20, 25, 30],
# [ 6, 12, 18, 24, 30, 36]])

(補足) 上記のコードは、下記のコードでも同じ結果が得られる。

prod = lambda X, Y: X * Y

matrix = prod(*np.meshgrid(arr, arr))
# 上記と同じ行列が得られる

ここで、リストに対して要素が追加されることになった、とする。

arr_new = np.arange(7, 10)  # array([7, 8, 9])

もともとの要素 arr と追加要素 arr_new を結合した新しいリストについて、先ほど計算したような行列を再び求めたい。

愚直に書くとこうなる。

arr_all = np.concatenate([arr, arr_new])

matrix_new = prod(*np.meshgrid(arr_all, arr_all))
# array([[ 1, 2, 3, 4, 5, 6, 7, 8, 9],
# [ 2, 4, 6, 8, 10, 12, 14, 16, 18],
# [ 3, 6, 9, 12, 15, 18, 21, 24, 27],
# [ 4, 8, 12, 16, 20, 24, 28, 32, 36],
# [ 5, 10, 15, 20, 25, 30, 35, 40, 45],
# [ 6, 12, 18, 24, 30, 36, 42, 48, 54],
# [ 7, 14, 21, 28, 35, 42, 49, 56, 63],
# [ 8, 16, 24, 32, 40, 48, 56, 64, 72],
# [ 9, 18, 27, 36, 45, 54, 63, 72, 81]])

しかし、もともとの要素 arr に関する計算済みの行列 matrix があるので、この行列を利用して新しい要素に関する拡張部分だけを計算することにして、計算量を抑えたい。

この差分計算を NumPy で書きたい。


解法

このような関数を作ると実現できる。

def extend_matrix(matrix_base, arr_base, arr_new, func):

part0 = func(*np.meshgrid(arr, arr_new))
part1 = func(*np.meshgrid(arr_new, arr_new))
return np.concatenate([
np.concatenate([matrix_base, part0], axis=0),
np.concatenate([part0.T, part1], axis=0),
], axis=1)

行列の新しく増える部分だけを計算して、既存の行列と結合している。

使うときは、こう。

matrix_new = extend_matrix(matrix, arr, arr_new, prod)

# array([[ 1, 2, 3, 4, 5, 6, 7, 8, 9],
# [ 2, 4, 6, 8, 10, 12, 14, 16, 18],
# [ 3, 6, 9, 12, 15, 18, 21, 24, 27],
# [ 4, 8, 12, 16, 20, 24, 28, 32, 36],
# [ 5, 10, 15, 20, 25, 30, 35, 40, 45],
# [ 6, 12, 18, 24, 30, 36, 42, 48, 54],
# [ 7, 14, 21, 28, 35, 42, 49, 56, 63],
# [ 8, 16, 24, 32, 40, 48, 56, 64, 72],
# [ 9, 18, 27, 36, 45, 54, 63, 72, 81]])

追加要素についての差分計算だけで全体の行列を求めることができた。