3行で
- 類似度に基づくレコメンドやk近傍法など、データ同士の可能なペアすべてに対して計算を行いたい場面がある
- forループを書けばできるが、numpyなりtfなりの機能を活かしたい
- ブロードキャストを使って簡潔に記述できる
問題設定
$-1\leq x \leq 1$, $-1\leq y \leq 1$を満たすすべての$(x, y)$の集合を$A$とする。
$A$から任意の元を選んだ際、$A$に含まれる格子点と選んだ元とのL1距離を計算する。
なお、$(x_1, y_1)$と$(x_2,y_2)$のL1距離は以下で表される
d_1((x_1, y_1),(x_2,y_2)) = |x_1 - x_2|+|y_1 - y_2|
考え方
選んだ元とすべての格子点とのL1距離を計算し、距離が小さい順に5つ選ぶ。
L1距離の計算を一斉に行うため、次のようなnp.ndarray(もしくはtf.tensor)を考える。
lattice=np.array([[ 1, 1],
[ 1, 0],
[ 1, -1],
[ 0, 1],
[ 0, 0],
[ 0, -1],
[-1, 1],
[-1, 0],
[-1, -1]]) #shape = (9, 2)
今、仮に選んだ元が$(0.1,0.5)$だったとする。実は、L1距離の計算として次の記法は有効
data = np.array([0.1,0.5])
l1_dist = np.sum(np.abs(data-lattice),axis=1)
一見なんの変哲もない式に見えるが、data-lattice
の部分でshapeが異なるもの同士の引き算を行っている。
ここで自動的に2つのshapeがブロードキャストによって調整されている。
(参考: https://numpy.org/doc/stable/user/basics.broadcasting.html)
公式によれば、shapeの後ろから次元を比較し、片方の次元が1の場合もう片方に合わせてコピーで次元が増やされる。
今回の場合、data
のshapeは(2,)
、lattice
のshape
は(9,2)
なので、data
側の次元が調整され
array([[0.1, 0.5],
[0.1, 0.5],
[0.1, 0.5],
[0.1, 0.5],
[0.1, 0.5],
[0.1, 0.5],
[0.1, 0.5],
[0.1, 0.5],
[0.1, 0.5]])
と見なされて引き算が実行された。
そしてnp.abs
でelement wiseに絶対値を計算し適切なaxis
に沿って和を取ればいい。
l1_dist
は次のようなshapeが(9,)
のnp.ndarray
になる
array([1.4, 1.4, 2.4, 0.6, 0.6, 1.6, 1.6, 1.6, 2.6])
バッチ処理への応用
L1距離を計算する対象の元を2個以上に増やしても同様の考え方ができる。
仮に、対象の元を2つとし、それぞれを$(0.1,0.5),(0.7,0.8)$とする。
今度はおそらくdata
が次の形式で供給されるだろう
data = np.array([[0.1, 0.5],
[0.7, 0.8]]) # shape = (2,2)
この場合、data-lattice
でブロードキャストは発生せずエラーになる。
shapeの後ろから次元を比較し、片方の次元が1の場合ではなくなったからである。
対処法はnp.expand_dims
で次元が1の軸を追加すればいい。
data = np.expand_dims(data,axis=1) # dataのshape = (2,1,2) latticeの(9,2)と比較してaxis=1のdataが9つ複製され、axis=0のlatticeが2つ複製される
l1_dist = np.sum(np.abs(data-lattice),axis=2) # (2,9,2)同士の引き算ののち、axis=2のsumが行われる。sumのaxisがexpandのせいで変わっているので注意
とすれば、l1_dist
は
array([[1.4, 1.4, 2.4, 0.6, 0.6, 1.6, 1.6, 1.6, 2.6],
[0.5, 1.1, 2.1, 0.9, 1.5, 2.5, 1.9, 2.5, 3.5]]) # shape = (2, 9)
となる。
ぼやき
まあコードは簡潔になるんだけど、暗黙的にshapeが変わるのって可読性をめちゃくちゃ下げてないかなと思う。
処理時間・可読性・簡潔さを兼ね備えたもっといいプラクティスがあれば知りたい。
参考
TensorFlow機械学習クックブック Pythonベースの活用レシピ60+
(挙げといてなんだが、この本を買うのはあまりおすすめしません・・・)