統計学でよく使われるカイ二乗検定。サンプル群にお互い関連性があるかどうか帰無仮説で検証するやつです。これをプログラムできたら楽しいだろうなと思い組んでみました。
カイ二乗検定の計算方法については下記を参照しました。
まずはPythonから
コード
kai.py
import numpy as np
def chi_squared(array):
# 縦横合計値を取得
low_sums = np.array(array).sum(axis=1)
col_sums = np.array(array).sum(axis=0)
# 総合計を取得
the_total = sum(low_sums)
# 期待度数を取得
ex_freq = []
for i in low_sums:
for j in col_sums:
ex_freq.append(i * j/the_total)
pass
ex_freq = np.array(np.array_split(ex_freq, len(array)))
diff = np.array(array) - ex_freq
# 多行多列に対応できるように
ex_freq_flt = ex_freq.flatten()
diff_flt = diff.flatten()
return sum(diff_flt ** 2 / ex_freq_flt)
pass
def d_f(array):
s = list(np.array(array).shape)
d = 1
for i in s:
d *= (i - 1)
return d
waku_mogu = [[435, 165],
[265, 135]]
print("カイ2乗数は", chi_squared(waku_mogu))
print("自由度は ", d_f(waku_mogu))
"""
実行結果
カイ2乗数は 4.464285714285714
自由度は 1
"""
列の和、Numpyで一発ででるんですね。すごいです。
つづいて同じ動きをするプログラムをJavaで
コード
Kai.java
import java.util.ArrayList;
import java.util.List;
import java.util.Iterator;
public class Kai {
public static void main(String[] args) {
Calc c = new Calc();
double[][] waku_mogu = {{435, 165}, {265, 135}};
System.out.println("カイ2乗数は " + c.chi_squared(waku_mogu));
System.out.println("自由度は " + c.d_f(waku_mogu));
}
}
class Calc {
public double chi_squared(double[][] arr) {
List<Double> low_sums = new ArrayList<>();
List<Double> col_sums = new ArrayList<>();
List<Double> ex_freq = new ArrayList<>();
List<Double> diff = new ArrayList<>();
//各行の合計値を求める。
for (int i = 0; i < arr.length; i++) {
double total_l = 0;
for (int j = 0; j < arr[i].length; j++) {
total_l += arr[i][j];
}
low_sums.add(total_l);
}
//各列の合計値を求める。これが一番苦労した点...Numpyだと一行なのに
for (int j = 0; j < arr[0].length; j++) {
double total_c = 0;
for (int i = 0; i < arr.length; i++) {
total_c += arr[i][j];
}
col_sums.add(total_c);
}
double the_total = 0;
Iterator<Double> iterator = low_sums.iterator();
while (iterator.hasNext()) {
double i = iterator.next();
the_total += i;
}
iterator = low_sums.iterator();
while (iterator.hasNext()) {
double i = iterator.next();
Iterator<Double> iterator2 = col_sums.iterator();
while (iterator2.hasNext()) {
double j = iterator2.next();
ex_freq.add(i * j / the_total);
}
}
// 多行多列に対応できるように、二番目に苦労した点
int count = 0;
for (int i = 0; i < arr.length; i++) {
for (int j = 0; j < arr[i].length; j++) {
diff.add(arr[i][j] - ex_freq.get(count));
count++;
}
}
double chi_val = 0;
for (int i = 0; i < ex_freq.size(); i++) {
chi_val += Math.pow(diff.get(i), 2) / ex_freq.get(i);
}
return chi_val;
}
public int d_f(double[][] arr) {
return (arr.length - 1) * (arr[0].length - 1);
}
}
/*
実行結果
カイ2乗数は 4.464285714285714
自由度は 1
*/
列の和を出すところに相当苦労しました。
自分が下手なのももちろんですが、Pythonのほうがかなりスッキリ記述できますね。
でもJavaも自分で安全を確認しながら組んでいっている感じが好きでもあります。