概要
PyTorch で与えた二つのテンソルの距離行列を計算する torch.cdist
が変な実装になっていると聞いて数学的にどうなっているか調べたら応用の利きそうな方法だったので一般化してみた.
torch.cdist
torch.cdist
(p=2
) の実装
PyTorch の torch.cdist
は与えられた二つの $n\times d$, $m\times d$ 次元テンソルについて, 以下のように $n \times m$ 通りの総当たりで $d$ 次元空間の $p$ 次ミンコフスキー距離 ($p = 2$ ならユークリッド距離) による距離行列を計算する関数である1:
$$
\begin{gather*}
D_p(x, y)_{ij} = \|x _{i}-y _{j}\|_p = \left(\sum _{k=0}^{d-1}|x _{ik}-y _{jk}|^p\right)^{\frac{1}{p}}, \\
0 \leq i \leq n-1, \quad 0 \leq j \leq m-1. \\
\end{gather*}
$$
絶対値の偶数乗は常に絶対値を取らず偶数乗した結果に一致するため, $p$ が偶数の時は $|x_{ik}-y_{jk}|^p$ の絶対値記号は通常の括弧に置き換えられることに注意しておこう.
この処理自体は SciPy や scikit-learn にも類似の関数が実装されており, それぞれ scipy.spatial.distance.cdist
, sklearn.metrics.pairwise_distances
で実行できるが, $p = 2$ の時 (すなわちユークリッド距離を計算する時), PyTorch のこれは $(\sum_k(x_{ik}-y_{jk})^2)^{\frac{1}{2}}$ を愚直に計算しない実装になっているらしい.
確かに以下の cpp のソースコードは一見してユークリッド距離に見えない妙な実装になっていて, 公式リファレンスにも「ユークリッド距離の計算には matrix multiplication approach を使う」と普通に実装してなさそうな記述がある.
Tensor _euclidean_dist(const Tensor& x1, const Tensor& x2) {
/** This function does the fist part of the euclidean distance calculation
* We divide it in two steps to simplify dealing with subgradients in the
* backward step */
Tensor x1_norm = x1.pow(2).sum(-1, true);
Tensor x1_pad = at::ones_like(x1_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor x2_norm = x2.pow(2).sum(-1, true);
Tensor x2_pad = at::ones_like(x2_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor x1_ = at::cat({x1.mul(-2), std::move(x1_norm), std::move(x1_pad)}, -1);
Tensor x2_ = at::cat({x2, std::move(x2_pad), std::move(x2_norm)}, -1);
Tensor result = x1_.matmul(x2_.mT());
result.clamp_min_(0).sqrt_();
return result;
}
‘use_mm_for_euclid_dist_if_necessary’ - will use matrix multiplication approach to calculate euclidean distance (p = 2) if P > 25 or R > 25 ‘use_mm_for_euclid_dist’ - will always use matrix multiplication approach to calculate euclidean distance (p = 2) ‘donot_use_mm_for_euclid_dist’ - will never use matrix multiplication approach to calculate euclidean distance (p = 2) Default: use_mm_for_euclid_dist_if_necessary.
torch.cdist
(p=2
) の実装はユークリッド距離を計算するか
まずはこの torch.cdist
(p=2
) の実装が確かにユークリッド距離の計算となっていることを確かめよう.
ソースコードの当該部分を読み解くと以下のような実装になっていることが分かる:
$$
\begin{gather*}
D_2^\mathrm{torch}(x, y) = (\varphi(x)\psi(y)^t)^{\frac{1}{2}}, \\
\varphi(x) := \left[-2x, \sum _{k=0}^{d-1}x _{:k}^2, \mathbb{1} _{n\times1}\right] \in \mathbb{R}^{n\times(d+2)}, \\
\psi(y) := \left[y, \mathbb{1} _{m\times1}, \sum _{k=0}^{d-1}y _{:k}^2\right] \in \mathbb{R}^{m\times(d+2)}, \\
\end{gather*}
$$
ただし $x_{:k}$ は $x$ の第 $k$ 列を取り出した $n\times1$ テンソルであり ($y_{:k}$ も同様), $\mathbb{1}_{n\times 1}$ (または $\mathbb{1} _{m\times1}$) は全ての成分が $1$ の $n\times1$ (または $m\times1$) テンソルを表す.
また, $x _{:k}^2, y _{:k}^2$ における二乗及び最終的に $D_2^\mathrm{torch}(x, y)$ を求める時の $1/2$ 乗はテンソルの成分毎の演算である.
$\varphi(x), \psi(y)$ の定義でテンソルを並べているのは結合を表しており, ブロック行列と同様の記法だと思ってもらいたい.
この $D_2^\mathrm{torch}(x, y)$ の $(i, j)$ 成分の二乗を計算してみると, $x, y$ の第 $i$ 行, 第 $j$ 行をそれぞれ $x_i, y_j$ と書いて, 内積を $\langle\cdot, \cdot\rangle$ で表せば,
$$
\begin{align*}
D_2^\mathrm{torch}(x, y)_{ij}^2 & = (\varphi(x)\psi(y)^t) _{ij} \\
& =\left\langle\left[-2x_i, \sum _{k=0}^{d-1}x _{ik}^2, 1\right], \left[y_j, 1, \sum _{k=0}^{d-1}y _{jk}^2\right]\right\rangle \\
& = -2\sum _{k=0}^{d-1}x _{ik}y _{jk}+\sum _{k=0}^{d-1}x _{ik}^2+\sum _{k=0}^{d-1}y _{jk}^2 \\
& = \sum _{k=0}^{d-1}(x _{ik}^2-2x _{ik}y _{jk}+y _{jk}^2) \\
& = \sum _{k=0}^{d-1}(x _{ik}-y _{jk})^2 = D_2(x, y) _{ij}^2, \\
\end{align*}
$$
となり, これは $x$ の第 $i$ 行ベクトルと $y$ の第 $j$ 行ベクトルのユークリッド距離の二乗になっていることが分かる.
従って, この torch.dist
(p=2
) の実装は確かにユークリッド距離の計算となっていることが確かめられた.
自動微分における利点
後に見るようにこの実装は高速であるという特徴もあるが, 自動微分を考える上でも利点がある.
$x, y$ の $2$ 変数関数を $x$ 変数行列と $y$ 変数行列の積に分解しているため,
$$
\frac{\partial}{\partial x}(\varphi(x)\psi(y)^t) = \frac{\partial\varphi}{\partial x}(x)\psi(y)^t, \quad \frac{\partial}{\partial y}(\varphi(x)\psi(y)^t) = \varphi(x)\frac{\partial\psi}{\partial y}(y)^t,
$$
と偏微分を計算でき, forward で計算済みのはずの $\varphi(x), \psi(y)$ を再利用できる点である.
一般化
$n\times d$ 次元テンソル $x$ と $m\times d$ 次元テンソル $y$ を引数とする関数 $F: \mathbb{R}^{n\times d}\times\mathbb{R}^{m\times d} \rightarrow \mathbb{R}^{n\times m}$ で, $(i, j)$ 成分が $2$ 変数関数 $f: \mathbb{R}^d\times\mathbb{R}^d \rightarrow \mathbb{R}$ によって,
$$
F(x, y)_{ij} = f(x_i, y_j),
$$
と定まるものを考える.
以降, $\varphi$ 及び $\psi$ は節によって定義 (特に定義域と値域) が変わることに注意しておこう.
ただし, それぞれ $x, y$ を特徴量変換する関数を表す点では共通している.
各成分が内積で書ける場合
torch.cdist
の実装において $(i, j)$ 成分を計算する時,
$$
D_2^\mathrm{torch}(x, y)_{ij}^2 =\left\langle\left[-2x_i, \sum _{k=0}^{d-1}x _{ik}^2, 1\right], \left[y_j, 1, \sum _{k=0}^{d-1}y _{jk}^2\right]\right\rangle,
$$
と内積の形に変形されており, ここからこの手法を一般化できる.
つまり, 二つの関数 $\varphi, \psi: \mathbb{R}^d \rightarrow \mathbb{R}^l$ があって $f(x_i, y_j) = \langle\varphi(x_i), \psi(y_j)\rangle$ と書ける時,
$$
\begin{align*}
\varphi(x) & = \varphi\left(\begin{bmatrix} x_0 \\ x_1 \\ \vdots \\ x_{n-1} \end{bmatrix}\right) := \begin{bmatrix} \varphi(x_0) \\ \varphi(x_1) \\ \vdots \\ \varphi(x_{n-1}) \end{bmatrix} \in \mathbb{R}^{n\times l}, \\
\psi(y) & = \psi\left(\begin{bmatrix} y_0 \\ y_1 \\ \vdots \\ y_{m-1} \end{bmatrix}\right) := \begin{bmatrix} \psi(y_0) \\ \psi(y_1) \\ \vdots \\ \psi(y_{m-1}) \end{bmatrix} \in \mathbb{R}^{m\times l}, \\
\end{align*}
$$
という所謂ベクトル化 (vectorization) による記法を使うと,
$$
\begin{align*}
F(x, y) & = \begin{bmatrix} \langle\varphi(x_0), \psi(y_0)\rangle & \langle\varphi(x_0), \psi(y_1)\rangle & \dots & \langle\varphi(x_0), \psi(y_{m-1})\rangle \\ \langle\varphi(x_1), \psi(y_0)\rangle & \langle\varphi(x_1), \psi(y_1)\rangle & \dots & \langle\varphi(x_1), \psi(y_{m-1})\rangle \\ \vdots & \vdots & \ddots & \vdots \\ \langle\varphi(x_{n-1}), \psi(y_0)\rangle & \langle\varphi(x_{n-1}), \psi(y_1)\rangle & \dots & \langle\varphi(x_{n-1}), \psi(y_{m-1})\rangle \end{bmatrix} \\
& = \begin{bmatrix} \varphi(x_0) \\ \varphi(x_1) \\ \vdots \\ \varphi(x_{n-1}) \end{bmatrix}\begin{bmatrix} \psi(y_0) \\ \psi(y_1) \\ \vdots \\ \psi(y_{m-1}) \end{bmatrix}^t = \varphi(x)\psi(y)^t, \\
\end{align*}
$$
と変形でき, $F(x, y)$ は torch.cdist
の場合と同様の形で $x$ 変数行列と $y$ 変数行列の積に分解される.
各成分が成分毎の内積の和で書ける場合
前節の一般化は幅広いケースを大雑把にまとめているので, もう少し torch.cdist
に寄せた話をしよう.
ユークリッド距離の計算中に現れる差の二乗和 $\sum_{k=0}^{d-1}(x_{ik}-y_{jk})^2$ の各項は $1$ 次元ベクトル (つまりスカラー) の差の二乗和と見なすことができ, それ自体が内積で書けることが分かる.
つまり,
$$
\sum_{k=0}^{d-1}(x_{ik}-y_{jk})^2 = \sum_{k=0}^{d-1}\langle[-2x_{ik}, x_{ik}^2, 1], [y_{jk}, 1, y_{jk}^2]\rangle,
$$
という $d+2$ 次元ベクトル同士の内積の和で書けて, $g: \mathbb{R}\times\mathbb{R} \rightarrow \mathbb{R}$ 及び $\varphi, \psi: \mathbb{R} \rightarrow \mathbb{R}^l$ により,
$$
F(x, y)_{ij} = f(x_i, y_j) := \sum _{k=0}^{d-1}g(x _{ik}, y _{jk}) = \sum _{k=0}^{d-1}\langle\varphi(x _{ik}), \psi(y _{jk})\rangle,
$$
と書ける場合に一般化できる.
これはベクトル化によって,
$$
F(x, y) = \sum_{k=0}^{d-1}g(x_{:k}, y_{:k}) = \sum_{k=0}^{d-1}\varphi(x_{:k})\psi(y_{:k})^t,
$$
という行列積の和に変形でき, さらに,
$$
\begin{align*}
\varphi(x) & := [\varphi(x_{:0}), \varphi(x_{:1}), \dots, \varphi(x_{:d-1})] \in \mathbb{R}^{n\times ld}, \\
\psi(y) & := [\psi(y_{:0}), \psi(y_{:1}), \dots, \psi(y_{:d-1})] \in \mathbb{R}^{m\times ld}, \\
\end{align*}
$$
と書けば, ブロック行列の公式から,
$$
F(x, y) = \sum_{k=0}^{d-1}\varphi(x_{:k})\psi(y_{:k})^t = \varphi(x)\psi(y)^t,
$$
とも書き換えられる.
各成分が多項式関数の場合
前節・前々節の手法が使えるか判別が容易な例に, $f(x_i, y_j)$ が $x_i, y_j$ の各成分による多項式関数となっている例がある.
ユークリッド距離の二乗 $\sum_{k=0}^{d-1}(x_{ik}^2-x_{ik}y_{jk}+y_{jk}^2)$ もその例の一つである.
多項式関数は非負整数列 $s_0, s_1, \dots, s_{d-1}, t_0, t_1, \dots, t_{d-1}$ と定数 $\alpha_{p_0\dots p_{d-1}q_0\dots q_{d-1}}$ ($0 \leq p_k \leq s_{k-1}, 0 \leq q_k \leq t_{k-1}$) によって,
$$
f(x_i, y_j) = \sum_{p_0=0}^{s_0-1}\dots\sum_{p_{d-1}=0}^{s_{d-1}-1}\sum_{q_0=0}^{t_0-1}\dots\sum_{q_{d-1}=0}^{t_{d-1}-1}\alpha_{p_0\dots p_{d-1}q_0\dots q_{d-1}}\prod_{k=0}^{d-1}x_{ik}^{p_k}y_{jk}^{q_k},
$$
と書かれる関数である.
この時 $\varphi$ 及び $\psi$ を $p_0, \dots, p_{d-1}, q_0, \dots, q_{d-1}$ で $2d$ 重にインデクス付けて,
$$
\begin{cases}
\displaystyle \varphi(x_i)_{p_0\dots p _{d-1}q_0\dots q _{d-1}} := \alpha _{p_0\dots p _{d-1}q_0\dots q _{d-1}}\prod _{k=0}^{d-1}x _{ik}^{p_k}, \\
\displaystyle \psi(y_j) _{p_0\dots p _{d-1}q_0\dots q _{d-1}} := \prod _{k=0}^{d-1}y _{jk}^{q_k}, \\
\end{cases}
$$
で定めると $f(x_i, y_j) = \langle\varphi(x_i), \psi(y_j)\rangle$ と書くことができるようになる.
この時, 同次の項でまとめれば $\varphi(x_i), \psi(y_j)$ の次元削減になる.
つまり,
$$
\begin{cases}
\displaystyle \varphi(x_i)_{q_0\dots q _{d-1}} := \sum _{p_0=0}^{s_0-1}\dots\sum _{p _{d-1}=0}^{s _{d-1}-1}\alpha _{p_0\dots p _{d-1}q_0\dots q _{d-1}}\prod _{k=0}^{d-1}x _{ik}^{p_k}, \\
\displaystyle \psi(y_j) _{q_0\dots q _{d-1}} := \prod _{k=0}^{d-1}y _{jk}^{q_k}, \\
\end{cases}
$$
と取ることができ, $\alpha_{p_0\dots p_{d-1}q_0\dots q_{d-1}}$ のうち $q_0, q_1, \dots, q_{d-1}$ に依存しないものがあれば, 対応する $\prod_{k=0}^{d-1}x_{ik}^{p_k}$ の項も同様にまとめることができる.
ユークリッド距離の二乗の場合は全成分の最大次数が $2$ なので $s_0 = \dots = s_{d-1} = t_0 = \dots = t_{d-1} = 3$ であり,
$$
\begin{cases}
\text{ただ一つの $p_k$ または $q_k$ が $2$ でその他が 0} & \Longrightarrow & \alpha _{p_0\dots p _{d-1}q_0\dots q _{d-1}} = 1, \\
\text{ただ一組の $p_k$ 及び $q_k$ が共に $1$ でその他が 0} & \Longrightarrow & \alpha _{p_0\dots p _{d-1}q_0\dots q _{d-1}} = -2, \\
それ以外 & \Longrightarrow & \alpha _{p_0\dots p _{d-1}q_0\dots q _{d-1}} = 0, \\
\end{cases}
$$
と場合分けされるため, $\alpha _{p_0\dots p _{d-1}q_0\dots q _{d-1}} = 0$ の項を取り除いて同じ係数の項でまとめれば,
$$
\varphi(x_i)_h = \begin{cases} -2x _{ih} & h \lt d, \\ \displaystyle \sum _{k=0}^{d-1}x _{ik}^2 & h = d, \\ 1 & h = d+1, \end{cases} \qquad \psi(y_j)_h = \begin{cases} y _{jh} & h \lt d, \\ 1 & h = d, \\ \displaystyle \sum _{k=0}^{d-1}y _{jk}^2 & h = d+1, \end{cases} \\
$$
と整理されて torch.cdist
で実装されている形式が得られる.
ミンコフスキー距離への再輸入
ミンコフスキー距離は差の絶対値の $p$ 乗を含むため, $p$ が奇数の時は一般に内積で書くことはできない2 (脚注含め演習).
しかし $p$ が偶数ならば冒頭で述べたように絶対値は省いて考えることが出来るため, ミンコフスキー距離の $p$ 乗にここまでの手法を適用できる.
$p$ が偶数の時, $x_{ik}-y_{jk}$ (の絶対値) の $p$ 乗は二項定理により,
$$
(x_{ik}-y_{jk})^p = \sum_{h=0}^p(-1)^hC^p_hx_{ik}^hy_{jk}^{p-h}, \quad C^p_h := \frac{p!}{h!(p-h)!},
$$
と書かれる多項式関数となる.
$!$ は階乗の記号で, $C^p_h$ は $p, h$ で決まる定数となることに注意.
よってその $k$ に関する総和は,
$$
\sum_{k=0}^{d-1}\sum_{h=0}^p(-1)^hC^p_hx_{ik}^hy_{jk}^{p-h} = \langle[(-1)^hC^p_hx_i^h]_{h=0}^p, [y _i^{p-h}] _{h=0}^p\rangle,
$$
という内積で書ける.
$[(-1)^hC^p_hx_i^h]_{h=0}^p, [y_j^{p-h}] _{h=0}^p$ はそれぞれ $(-1)^hC^p_hx_i^h, y_j^{p-h}$ を $h=0$ から $h=p$ まで並べた (ブロック) ベクトルを表し, 再掲となるが $x_i^h$ 及び $y_i^{p-h}$ は行ベクトルの成分ごとの冪乗である.
これは $p=2$ の時と同様に $h=0, p$ の時に定数項が現れる成分をまとめることができ,
$$
\sum_{k=0}^{d-1}\sum_{h=0}^p(-1)^hC^p_hx_{ik}^hy_{jk}^{p-h} = \left\langle\left[[(-1)^hC^p_hx_i^h]_{h=1}^{p-1}, \sum _{k=0}^{d-1}x _{ik}^p, 1\right], \left[[y _j^{p-h}] _{h=1}^{p-1}, 1, \sum _{k=0}^{d-1}y _{jk}^p\right]\right\rangle,
$$
と変形される.
実験
偶数の $p$ 次ミンコフスキー距離を素直な方法と本記事の方法の二通りで実装してライブラリに実装済みの関数と比較してみる.
PyTorch と NumPy でそれぞれ試していて, どちらも以下の実装となっている:
- 自前実装 1: 素直な方法 → $f(x_i, y_j) = \left(\sum_{k=0}^{d-1}(x_{ik}-y_{jk})^p\right)^\frac{1}{p}$ をそのまま実装,
- 自前実装 2: 本記事の方法 → $f(x_i, y_j) = \langle\varphi(x_i), \psi(y_j)\rangle^\frac{1}{p}$ と内積に変形して実装.
これまでの解説からすると $1/p$ 乗が邪魔な感じもする3が, 比較したいライブラリの関数から取り払えない要素なので仕方なく合わせている.
二項係数 $C^p_h$ の計算方法はこのぐらいの計算量なら scipy.special.comb
を for 文で回してもいいが, この記事の「nを固定してrの変化を求める場合」の節辺りを NumPy で高速化している.
PyTorch 編
CPU 環境なので GPU だと速度関係が変わるかもしれない.
import torch
import numpy as np
# from scipy.special import comb
rng = np.random.default_rng()
# データ生成, torch のテンソル化
n, m = 200, 500
d = 100
x = rng.normal(size=[n, d])
y = rng.normal(size=[m, d])
x_tc = torch.from_numpy(x)
y_tc = torch.from_numpy(y)
# 自前実装 1 # 素直な方法
def cdist1(x, y, p=2):
return torch.pow(torch.sum(torch.pow(torch.unsqueeze(x, 1)-y, p), dim=-1), 1/p)
# 自前実装 2 # 本記事の方法
def cdist2(x, y, p=2):
coefs = torch.from_numpy(np.cumprod(np.hstack([-np.arange(p, 1, -1)/np.arange(1, p)])))
exp = torch.from_numpy(np.arange(1, p).astype(float))
x_p = torch.sum(torch.pow(x, p), dim=-1, keepdims=True)
x_h = torch.reshape(coefs*torch.pow(torch.unsqueeze(x, -1), exp), [x.shape[0], -1])
x_aug = torch.cat([x_h, x_p, torch.ones_like(x_p)], dim=-1)
y_p = torch.sum(torch.pow(y, p), dim=-1, keepdims=True)
y_h = torch.reshape(torch.pow(torch.unsqueeze(y, -1), p-exp), [y.shape[0], -1])
y_aug = torch.cat([y_h, torch.ones_like(y_p), y_p], dim=-1)
return torch.pow(x_aug@y_aug.T, 1/p)
# p = 2, 4, 6, 8, 10 で比較検証
# torch.cdist との max abs error と各実行時間を比較
mae = lambda a, b: torch.max(torch.abs(a-b))
for p in range(2, 11, 2):
print('p = {}'.format(p))
cd1 = cdist1(x_tc, y_tc, p)
cd2 = cdist2(x_tc, y_tc, p)
cdt = torch.cdist(x_tc, y_tc, p)
print('max abs error of cdist1 and torch.cdist: {}'.format(mae(cd1, cdt)))
print('max abs error of cdist2 and torch.cdist: {}'.format(mae(cd2, cdt)))
print('exe. time of cdist1 : ', end='')
%timeit -n 10 -r 10 cdist1(x_tc, y_tc, p)
print('exe. time of cdist2 : ', end='')
%timeit -n 10 -r 10 cdist2(x_tc, y_tc, p)
print('exe. time of torch.cdist: ', end='')
%timeit -n 10 -r 10 torch.cdist(x_tc, y_tc, p)
p = 2
max abs error of cdist1 and torch.cdist: 3.552713678800501e-15
max abs error of cdist2 and torch.cdist: 0.0
exe. time of cdist1 : 52.9 ms ± 5.46 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2 : 14.2 ms ± 2.97 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of torch.cdist: 14.7 ms ± 2.96 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
p = 4
max abs error of cdist1 and torch.cdist: 2.6645352591003757e-15
max abs error of cdist2 and torch.cdist: 3.552713678800501e-15
exe. time of cdist1 : 168 ms ± 7.27 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2 : 55.9 ms ± 8.02 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of torch.cdist: 115 ms ± 9.12 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
p = 6
max abs error of cdist1 and torch.cdist: 1.7763568394002505e-15
max abs error of cdist2 and torch.cdist: 3.552713678800501e-15
exe. time of cdist1 : 187 ms ± 12 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2 : 84.9 ms ± 9.34 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of torch.cdist: 116 ms ± 5.67 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
p = 8
max abs error of cdist1 and torch.cdist: 1.7763568394002505e-15
max abs error of cdist2 and torch.cdist: 5.329070518200751e-15
exe. time of cdist1 : 179 ms ± 10.6 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2 : 110 ms ± 10.9 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of torch.cdist: 114 ms ± 4.43 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
p = 10
max abs error of cdist1 and torch.cdist: 1.7763568394002505e-15
max abs error of cdist2 and torch.cdist: 1.4654943925052066e-14
exe. time of cdist1 : 170 ms ± 12.1 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2 : 148 ms ± 9.07 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of torch.cdist: 107 ms ± 6.75 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
max abs error はいずれも問題なし.
速度について p=2
で同じアルゴリズムの自前実装 2 と torch.cdist
が同等なのと p
が増えるに従って扱う行列サイズが増えるので自前実装 2 が遅くなっていっているのは予想通り.
よく分からないのが自前実装 1 と torch.cdist
が p=2
だけ速いところと p>=4
で自前実装 1 より torch.cdist
が速いところ.
p
が小さい時に自前実装 1 が自前実装 2 より速いのもよく分からない部分だが, たぶんハードウェア的に行列の積が最適化されているからじゃないかと思っている.
NumPy 編
比較対象の scipy.spatial.distance.cdist
は自前実装 1 と同じアルゴリズムになっている.
**
演算子の代わり np.power
を使うと p=2
で他の場合と同じぐらい遅くなるので注意.
# PyTorch の代わり NumPy で実装してみた版
# torch.cdist は SciPy の類似関数で代用 (実装アルゴリズムは素直な方法)
import numpy as np
from scipy.spatial import distance
# from scipy.special import comb
rng = np.random.default_rng()
# データ生成
n, m = 200, 500
d = 100
x = rng.normal(size=[n, d])
y = rng.normal(size=[m, d])
# 自前実装 1 # 素直な方法
def cdist1_np(x, y, p=2):
return np.sum((np.expand_dims(x, 1)-y)**p, axis=-1)**(1/p)
# 自前実装 2 # 本記事の方法
def cdist2_np(x, y, p=2):
coefs = np.cumprod(np.hstack([-np.arange(p, 1, -1)/np.arange(1, p)]))
exp = np.arange(1, p)
x_p = np.sum(x**p, axis=-1, keepdims=True)
x_h = np.reshape(coefs*np.expand_dims(x, -1)**exp, [x.shape[0], -1])
x_aug = np.concatenate([x_h, x_p, np.ones_like(x_p)], axis=-1)
y_p = np.sum(y**p, axis=-1, keepdims=True)
y_h = np.reshape(np.expand_dims(y, -1)**(p-exp), [y.shape[0], -1])
y_aug = np.concatenate([y_h, np.ones_like(y_p), y_p], axis=-1)
return (x_aug@y_aug.T)**(1/p)
# p = 2, 4, 6, 8, 10 で比較検証
# scipy.spatial.distance.cdist との max abs error と各実行時間を比較
mae = lambda a, b: np.max(np.abs(a-b))
for p in range(2, 11, 2):
print('p = {}'.format(p))
cd1 = cdist1_np(x, y, p)
cd2 = cdist2_np(x, y, p)
cds = distance.cdist(x, y, 'minkowski', p=p)
print('max abs error of cdist1_np and scipy.spatial.distance.cdist: {}'.format(mae(cd1, cds)))
print('max abs error of cdist2_np and scipy.spatial.distance.cdist: {}'.format(mae(cd2, cds)))
print('exe. time of cdist1_np : ', end='')
%timeit -n 10 -r 10 cdist1_np(x, y, p)
print('exe. time of cdist2_np : ', end='')
%timeit -n 10 -r 10 cdist2_np(x, y, p)
print('exe. time of scipy.spatial.distance.cdist: ', end='')
%timeit -n 10 -r 10 distance.cdist(x, y, 'minkowski', p=p)
p = 2
max abs error of cdist1_np and scipy.spatial.distance.cdist: 1.0658141036401503e-14
max abs error of cdist2_np and scipy.spatial.distance.cdist: 1.0658141036401503e-14
exe. time of cdist1_np : 79.6 ms ± 7.6 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2_np : 1.03 ms ± 132 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of scipy.spatial.distance.cdist: 7.27 ms ± 1.62 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
p = 4
max abs error of cdist1_np and scipy.spatial.distance.cdist: 2.6645352591003757e-15
max abs error of cdist2_np and scipy.spatial.distance.cdist: 3.552713678800501e-15
exe. time of cdist1_np : 485 ms ± 17 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2_np : 15.2 ms ± 1.32 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of scipy.spatial.distance.cdist: 392 ms ± 16.8 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
p = 6
max abs error of cdist1_np and scipy.spatial.distance.cdist: 1.7763568394002505e-15
max abs error of cdist2_np and scipy.spatial.distance.cdist: 3.552713678800501e-15
exe. time of cdist1_np : 489 ms ± 17.2 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2_np : 23.4 ms ± 4.48 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of scipy.spatial.distance.cdist: 378 ms ± 25.1 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
p = 8
max abs error of cdist1_np and scipy.spatial.distance.cdist: 8.881784197001252e-16
max abs error of cdist2_np and scipy.spatial.distance.cdist: 7.993605777301127e-15
exe. time of cdist1_np : 462 ms ± 23.7 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2_np : 29.2 ms ± 2.34 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of scipy.spatial.distance.cdist: 351 ms ± 20.9 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
p = 10
max abs error of cdist1_np and scipy.spatial.distance.cdist: 8.881784197001252e-16
max abs error of cdist2_np and scipy.spatial.distance.cdist: 2.0872192862952943e-14
exe. time of cdist1_np : 443 ms ± 23.1 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of cdist2_np : 35.4 ms ± 3.86 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
exe. time of scipy.spatial.distance.cdist: 334 ms ± 9.3 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
PyTorch 編と似たような感想だが, scipy.spatial.distance.cdist
が p=2
で (p>=4
より) 爆速なのは BLAS からユークリッド距離を計算する関数を取って来ているかららしい.
こちらは自前実装 2 も爆速で, この計算時間の増加の感じだと p=100
ぐらいまでは高速化の効果があるんじゃないかという印象4.