0
0

二次元〜多次元正規分布データの生成

Last updated at Posted at 2024-08-03

昔,以下のような関数を作りました。
2次元正規分布から,k次元正規分布までに対応しています。
指定した相関係数の通りのデータを作ります。


指定された相関係数行列を持つ変数の生成

特定の相関係数行列を持つ多変量データを生成する。

import numpy as np
from scipy.linalg import eigvals, inv, svd, cholesky
from scipy.stats import zscore

def gendat(n, r):
    #
    #
    def trimat(x):
        l = len(x)
        n = int((np.sqrt(1 + 8 * l) + 1) / 2)
        if l != n * (n - 1) / 2:
            raise Exception("length of vector is not just required")
        r = np.tri(n, n, -1)
        r[r == 1] = x
        r = r + r.T + np.identity(n)
        return r
    #
    #
    r = np.array(r)
    if np.ndim(r) == 0:
        r = np.array([[1, r], [r, 1]])
    elif np.ndim(r) == 1:
        r = trimat(r)
    size = r.shape
    if size[0] != size[1]:
        raise Exception("'r' must be a square matrix.")
    if any(np.diag(r) != 1):
        raise Exception("some diagonal of 'r' is not 1.")
    if any(np.ravel(r != r.T)):
        raise Exception("'r' is not a symmetry matrix.")
    if any(abs(np.ravel(r)) > 1):
        raise Exception("some element of 'r' is not in a range [-1, 1].")
    if any(eigvals(r) <= 0):
        raise Exception("'r' is not positive definite.")
    x = np.random.randn(n, size[0])
    x = zscore(x, ddof=1)
    r2 = np.corrcoef(x.T)
    solver2 = inv(r2)
    vec, val, junk = svd(r2, full_matrices=False)
    coeff = solver2 @ (np.sqrt(val) * vec)
    z = x @ coeff @ cholesky(r)
    return z

使用法

gendat(n, r)

引数

n サンプルサイズ
r 1 個の相関係数または相関係数行列またはその下三角行列のベクトル

戻り値

データ行列

使用例

2 変数データの生成

第 2 引数に,1 個の相関係数を指定する。

a = gendat(10,0.5)
print(np.corrcoef(a, rowvar=False))
[[1.  0.5]
 [0.5 1. ]]

3 変数以上のデータ生成

下三角行列をベクトルで与える場合

x = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
a = gendat(10, x)
print(a)
print(np.corrcoef(a, rowvar=False))
[[-0.53063107  1.03494526  1.07697067  0.6647251 ]
 [ 1.80820393  0.84075517  0.63006879  1.3337293 ]
 [ 1.02865631  0.20310229 -0.9718502   0.256431  ]
 [-0.87312874  0.19745704 -1.60090653 -1.95769795]
 [ 0.99255704 -1.68064095  0.80260234 -0.80927951]
 [-1.19798238 -1.86824642 -1.06679755 -0.86365449]
 [-1.08642493  0.51215329  1.02922946  0.65986376]
 [-0.08134605  0.51709947 -0.2893305  -0.46295722]
 [ 0.06823093 -0.24408027 -0.48852991  0.84976087]
 [-0.12813504  0.48745513  0.87854342  0.32907914]]
[[1.  0.1 0.2 0.4]
 [0.1 1.  0.3 0.5]
 [0.2 0.3 1.  0.6]
 [0.4 0.5 0.6 1. ]]

相関係数行列を与える場合

r = [[1.0, 0.5, 0.4], [0.5, 1.0, 0.3], [0.4, 0.3, 1.0]]
a = gendat(2000, r)
print(np.corrcoef(a, rowvar=False))
[[1.  0.5 0.4]
 [0.5 1.  0.3]
 [0.4 0.3 1. ]]
import matplotlib.pyplot as plt
plt.axes().set_aspect('equal', 'datalim')
plt.scatter(a[:, 0], a[:, 1], s=9, alpha=0.3)
plt.text(1.0, -3.0, "r = "+str(np.corrcoef(a.T)[0, 1]))
plt.xlabel("x1")
plt.ylabel("x2")

output_8_1.png

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0