12
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

情報幾何の数値解析

Last updated at Posted at 2025-02-28

はじめに

統計的推論や機械学習において、情報幾何は確率分布の空間をリーマン幾何の枠組みで扱う強力な手法です。本記事では、1次元正規分布を例に、以下の概念をPython(PyTorch)で実装しながら解説します。

  • フィッシャー情報行列
  • 三次形式($T_{i,j,k}$)
  • クリストフェル記号(Levi-Civita接続)
  • $α$-接続のクリストフェル記号
  • リーマン曲率テンソル、スカラー曲率

これらを用いることで、確率分布の幾何構造をより深く理解することができます。

1. フィッシャー情報行列の計算

フィッシャー情報行列は、統計モデルにおける情報量の尺度であり、統計推定の精度を評価する上で重要な役割を果たします。確率分布 $p(x | \theta)$ の対数尤度関数 $\log p(x | \theta)$ に対して、

I_{ij} = E\left[ \frac{\partial \log p}{\partial \theta_i} \frac{\partial \log p}{\partial \theta_j} \right]

で定義されます。Monte Carlo 法を用いて、この行列を近似計算します。

import torch

def compute_fisher_info(dist_constructor, params, num_samples=1000):
    n_params = len(params)
    fisher = torch.zeros(n_params, n_params, dtype=params[0].dtype)
    
    for _ in range(num_samples):
        dist = dist_constructor(*params)
        x = dist.sample()
        logp = dist.log_prob(x)
        
        grads = torch.autograd.grad(logp, params, create_graph=True)
        grads_vec = torch.stack([g.reshape(-1)[0] for g in grads])
        fisher += grads_vec.unsqueeze(1) @ grads_vec.unsqueeze(0)
    
    fisher /= num_samples
    return fisher

2. 三次形式の計算

次に、三次形式 $T_{i,j,k}$ を Monte Carlo 法で近似します。

T_{i,j,k} = E\left[ \frac{\partial \log p}{\partial \theta_i} \frac{\partial \log p}{\partial \theta_j} \frac{\partial \log p}{\partial \theta_k} \right]
def compute_third_order_tensor(dist_constructor, params, num_samples=1000):
    n_params = len(params)
    T = torch.zeros(n_params, n_params, n_params, dtype=params[0].dtype)
    
    for _ in range(num_samples):
        dist = dist_constructor(*params)
        x = dist.sample()
        logp = dist.log_prob(x)
        
        grads = torch.autograd.grad(logp, params, create_graph=True)
        grads_vec = [g.reshape(-1)[0] for g in grads]
        
        for i in range(n_params):
            for j in range(n_params):
                for k in range(n_params):
                    T[i, j, k] += grads_vec[i] * grads_vec[j] * grads_vec[k]
    
    T /= num_samples
    return T

3. Levi-Civita クリストフェル記号

フィッシャー計量の Levi-Civita 接続($\Gamma^{k}_{ij}$)を求めます。

\Gamma^k_{ij} = \frac{1}{2} \sum_l g^{kl} \left(\partial_i g_{jl} + \partial_j g_{il} - \partial_l g_{ij}\right)
def compute_christoffel_levi_civita(fisher, params):
    n = fisher.shape[0]
    fisher_inv = torch.inverse(fisher)
    d_g = torch.zeros(n, n, n, dtype=fisher.dtype)
    
    for i in range(n):
        for j in range(n):
            grad_list = torch.autograd.grad(fisher[i, j], params, retain_graph=True, allow_unused=True)
            for k, grad in enumerate(grad_list):
                if grad is not None:
                    d_g[k, i, j] = grad.reshape(-1)[0]
                else:
                    d_g[k, i, j] = 0.0
    
    Gamma = torch.zeros(n, n, n, dtype=fisher.dtype)
    for k in range(n):
        for i in range(n):
            for j in range(n):
                tmp = 0.0
                for l in range(n):
                    tmp += fisher_inv[k, l] * (d_g[i, j, l] + d_g[j, i, l] - d_g[l, i, j])
                Gamma[k, i, j] = 0.5 * tmp
    
    return Gamma

4. α-接続のクリストフェル記号

指数接続 ($\alpha=1$) の場合、クリストフェル記号は以下のように修正されます。

\Gamma^{(\alpha)k}_{ij} = \Gamma^{(0)k}_{ij} + \frac{\alpha}{2} T^k_{ij}
def compute_christoffel_alpha(Gamma_lc, T, fisher_inv, alpha=1.0):
    n = Gamma_lc.shape[0]
    T_up = torch.zeros_like(T)
    for l in range(n):
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    T_up[k, i, j] += fisher_inv[k, l] * T[l, i, j]
    
    Gamma_alpha = Gamma_lc.clone()
    Gamma_alpha += (alpha / 2.0) * T_up
    return Gamma_alpha

5. リーマン曲率テンソルとスカラー曲率

最後に、$α$-接続のリーマン曲率テンソルおよびスカラー曲率を計算します。

def compute_scalar_curvature_alpha(Riemann, fisher):
    Ricci = compute_ricci_tensor_alpha(Riemann)
    fisher_inv = torch.inverse(fisher)
    
    R_scalar = 0.0
    n = fisher.shape[0]
    for i in range(n):
        for j in range(n):
            R_scalar += fisher_inv[i, j] * Ricci[i, j]
    return R_scalar

6. 実行結果

自然パラメータ (θ1, θ2) で、1次元正規分布の下記の通り計算しました。

  • フィッシャー情報行列
  • 三次形式($T_{i,j,k}$)
  • クリストフェル記号(Levi-Civita接続)
  • $α$-接続のクリストフェル記号
  • リーマン曲率テンソル、スカラー曲率

Fisher Information Matrix:
tensor([[ 1.1310, -0.0842],
        [-0.0842,  2.6330]], grad_fn=<DivBackward0>)

(α=1) Riemann Curvature Tensor:
tensor([[[[ 0.0000,  0.0252],
          [-0.0252,  0.0000]],

         [[ 0.0000,  0.4545],
          [-0.4545,  0.0000]]],


        [[[ 0.0000, -0.1389],
          [ 0.1389,  0.0000]],

         [[ 0.0000, -0.0160],
          [ 0.0160,  0.0000]]]], grad_fn=<CopySlices>)

(α=1) Scalar Curvature: tensor(0.2972, grad_fn=<AddBackward0>)

おわりに

本記事では、1次元正規分布の情報幾何の数値解析を行い、フィッシャー情報行列からクリストフェル記号、リーマン曲率テンソル、スカラー曲率まで計算しました。情報幾何は深い理論ですが、実際にコードを書くことで理解が深まるので、ぜひ実装を試してみてください!

12
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?