はじめに
統計的推論や機械学習において、情報幾何は確率分布の空間をリーマン幾何の枠組みで扱う強力な手法です。本記事では、1次元正規分布を例に、以下の概念をPython(PyTorch)で実装しながら解説します。
- フィッシャー情報行列
- 三次形式($T_{i,j,k}$)
- クリストフェル記号(Levi-Civita接続)
- $α$-接続のクリストフェル記号
- リーマン曲率テンソル、スカラー曲率
これらを用いることで、確率分布の幾何構造をより深く理解することができます。
1. フィッシャー情報行列の計算
フィッシャー情報行列は、統計モデルにおける情報量の尺度であり、統計推定の精度を評価する上で重要な役割を果たします。確率分布 $p(x | \theta)$ の対数尤度関数 $\log p(x | \theta)$ に対して、
I_{ij} = E\left[ \frac{\partial \log p}{\partial \theta_i} \frac{\partial \log p}{\partial \theta_j} \right]
で定義されます。Monte Carlo 法を用いて、この行列を近似計算します。
import torch
def compute_fisher_info(dist_constructor, params, num_samples=1000):
n_params = len(params)
fisher = torch.zeros(n_params, n_params, dtype=params[0].dtype)
for _ in range(num_samples):
dist = dist_constructor(*params)
x = dist.sample()
logp = dist.log_prob(x)
grads = torch.autograd.grad(logp, params, create_graph=True)
grads_vec = torch.stack([g.reshape(-1)[0] for g in grads])
fisher += grads_vec.unsqueeze(1) @ grads_vec.unsqueeze(0)
fisher /= num_samples
return fisher
2. 三次形式の計算
次に、三次形式 $T_{i,j,k}$ を Monte Carlo 法で近似します。
T_{i,j,k} = E\left[ \frac{\partial \log p}{\partial \theta_i} \frac{\partial \log p}{\partial \theta_j} \frac{\partial \log p}{\partial \theta_k} \right]
def compute_third_order_tensor(dist_constructor, params, num_samples=1000):
n_params = len(params)
T = torch.zeros(n_params, n_params, n_params, dtype=params[0].dtype)
for _ in range(num_samples):
dist = dist_constructor(*params)
x = dist.sample()
logp = dist.log_prob(x)
grads = torch.autograd.grad(logp, params, create_graph=True)
grads_vec = [g.reshape(-1)[0] for g in grads]
for i in range(n_params):
for j in range(n_params):
for k in range(n_params):
T[i, j, k] += grads_vec[i] * grads_vec[j] * grads_vec[k]
T /= num_samples
return T
3. Levi-Civita クリストフェル記号
フィッシャー計量の Levi-Civita 接続($\Gamma^{k}_{ij}$)を求めます。
\Gamma^k_{ij} = \frac{1}{2} \sum_l g^{kl} \left(\partial_i g_{jl} + \partial_j g_{il} - \partial_l g_{ij}\right)
def compute_christoffel_levi_civita(fisher, params):
n = fisher.shape[0]
fisher_inv = torch.inverse(fisher)
d_g = torch.zeros(n, n, n, dtype=fisher.dtype)
for i in range(n):
for j in range(n):
grad_list = torch.autograd.grad(fisher[i, j], params, retain_graph=True, allow_unused=True)
for k, grad in enumerate(grad_list):
if grad is not None:
d_g[k, i, j] = grad.reshape(-1)[0]
else:
d_g[k, i, j] = 0.0
Gamma = torch.zeros(n, n, n, dtype=fisher.dtype)
for k in range(n):
for i in range(n):
for j in range(n):
tmp = 0.0
for l in range(n):
tmp += fisher_inv[k, l] * (d_g[i, j, l] + d_g[j, i, l] - d_g[l, i, j])
Gamma[k, i, j] = 0.5 * tmp
return Gamma
4. α-接続のクリストフェル記号
指数接続 ($\alpha=1$) の場合、クリストフェル記号は以下のように修正されます。
\Gamma^{(\alpha)k}_{ij} = \Gamma^{(0)k}_{ij} + \frac{\alpha}{2} T^k_{ij}
def compute_christoffel_alpha(Gamma_lc, T, fisher_inv, alpha=1.0):
n = Gamma_lc.shape[0]
T_up = torch.zeros_like(T)
for l in range(n):
for i in range(n):
for j in range(n):
for k in range(n):
T_up[k, i, j] += fisher_inv[k, l] * T[l, i, j]
Gamma_alpha = Gamma_lc.clone()
Gamma_alpha += (alpha / 2.0) * T_up
return Gamma_alpha
5. リーマン曲率テンソルとスカラー曲率
最後に、$α$-接続のリーマン曲率テンソルおよびスカラー曲率を計算します。
def compute_scalar_curvature_alpha(Riemann, fisher):
Ricci = compute_ricci_tensor_alpha(Riemann)
fisher_inv = torch.inverse(fisher)
R_scalar = 0.0
n = fisher.shape[0]
for i in range(n):
for j in range(n):
R_scalar += fisher_inv[i, j] * Ricci[i, j]
return R_scalar
6. 実行結果
自然パラメータ (θ1, θ2) で、1次元正規分布の下記の通り計算しました。
- フィッシャー情報行列
- 三次形式($T_{i,j,k}$)
- クリストフェル記号(Levi-Civita接続)
- $α$-接続のクリストフェル記号
- リーマン曲率テンソル、スカラー曲率
Fisher Information Matrix:
tensor([[ 1.1310, -0.0842],
[-0.0842, 2.6330]], grad_fn=<DivBackward0>)
(α=1) Riemann Curvature Tensor:
tensor([[[[ 0.0000, 0.0252],
[-0.0252, 0.0000]],
[[ 0.0000, 0.4545],
[-0.4545, 0.0000]]],
[[[ 0.0000, -0.1389],
[ 0.1389, 0.0000]],
[[ 0.0000, -0.0160],
[ 0.0160, 0.0000]]]], grad_fn=<CopySlices>)
(α=1) Scalar Curvature: tensor(0.2972, grad_fn=<AddBackward0>)
おわりに
本記事では、1次元正規分布の情報幾何の数値解析を行い、フィッシャー情報行列からクリストフェル記号、リーマン曲率テンソル、スカラー曲率まで計算しました。情報幾何は深い理論ですが、実際にコードを書くことで理解が深まるので、ぜひ実装を試してみてください!