Help us understand the problem. What is going on with this article?

時間比較:Python で相関係数計算

More than 1 year has passed since last update.

はじめに

大気海洋のデータ解析では,相関係数マップを描くことがよくある.
Pythonでは,NumpyやScipy に相関係数を計算する関数が用意されているので,それらを用いて相関係数のマップを描くことができる.
しかし,やり方が複数あるので,ここでは,それらのパフォーマンスを比較してみる.

開発環境

  • Python 2.7.13
  • Numpy 1.12.1
  • Scipy 0.19.0

相関係数を計算する方法

1. Scipy の関数を使ってループを回す方法

def crr_loop_scipy(x1, x2):
    ret = numpy.empty([x1.shape[1], x2.shape[1]])
    for jj in xrange(ret.shape[1]):
        for ii in xrange(ret.shape[0]):
            ret[ii,jj], _ = scipy.stats.pearsonr(x1[:,ii], x2[:,jj])
    return ret

2. Numpy の関数を使ってループを回す方法

def crr_loop_numpy(x1, x2):
    ret = numpy.empty([x1.shape[1], x2.shape[1]])
    for jj in xrange(ret.shape[1]):
        for ii in xrange(ret.shape[0]):
            ret[ii,jj] = numpy.corrcoef(x1[:,ii], x2[:,jj])[1,0]
    return ret

3. Numpy の関数を使うがループを回さない方法

def crr_numpy(x1, x2):
    return numpy.corrcoef(x1, x2, rowvar=False)[0:x1.shape[1], x1.shape[1]:]

4. 自分で共分散行列を計算する方法

def crr_original(x1, x2):
    crscov = np.dot(x1.T, x2) / x1.shape[0]
    std1 = np.std(x1, axis=0).reshape(x1.shape[1], 1)
    std2 = np.std(x2, axis=0).reshape(x2.shape[1], 1) 
    crsstd = np.dot(std1, std2.T)
    return crscov / crsstd

実験設定

  • 上記にある相関係数を計算する方法を4種類試す.
  • 計算するデータの時間方向の大きさは1000とし,x方向の大きさ(グリッドサイズ)は可変とする(下にこのデータを生成する関数を示す).
def generate_data(nx):
    nt = 1000
    data1 = numpy.random.randn(nt, nx)
    data2 = numpy.random.randn(nt, nx)
    return data1, data2
  • 時間の計測には,Python標準のtime モジュールを使用し,2回同じ実験を行った平均値をその実験の計算時間とする.
import time

def measure_time(f, *args):
    start = time.time()
    f(*args)
    return time.time() - start

def measure_mean_time(f, *args):
    n = 2
    return numpy.mean([measure_time(f, *args) for i in xrange(n)])

結果

crr_ventimark.png

結果は,

  • 3. Numpy の関数を使うがループを回さない方法(non-loop_np)
  • 4. 自分で共分散行列を計算する方法(original)

が圧倒的に早かった.
また,numpy よりもscipy のほうが早いので,今回のように相関係数マップを目的としなくても,scipy のほうを使った方がよいかもしれない.

実験に使用したソースコード全体

import numpy as np
from scipy import stats
import time


def crr_loop_scipy(x1, x2):
    ret = np.empty([x1.shape[1], x2.shape[1]])
    for jj in xrange(ret.shape[1]):
        for ii in xrange(ret.shape[0]):
            ret[ii,jj], _ = stats.pearsonr(x1[:,ii], x2[:,jj])
    return ret


def crr_loop_numpy(x1, x2):
    ret = np.empty([x1.shape[1], x2.shape[1]])
    for jj in xrange(ret.shape[1]):
        for ii in xrange(ret.shape[0]):
            ret[ii,jj] = np.corrcoef(x1[:,ii], x2[:,jj])[1,0]
    return ret


def crr_numpy(x1, x2):
    return np.corrcoef(x1, x2, rowvar=False)[0:x1.shape[1], x1.shape[1]:]


def crr_original(x1, x2):
    crscov = np.dot(x1.T, x2) / x1.shape[0]
    std1 = np.std(x1, axis=0).reshape(x1.shape[1], 1)
    std2 = np.std(x2, axis=0).reshape(x2.shape[1], 1) 
    crsstd = np.dot(std1, std2.T)
    return crscov / crsstd


def generate_data(nx):
    nt = 1000
    data1 = np.random.randn(nt, nx)
    data2 = np.random.randn(nt, nx)
    return data1, data2


def measure_time(f, *args):
    start = time.time()
    f(*args)
    return time.time() - start


def measure_mean_time(f, *args):
    n = 2
    return np.mean([measure_time(f, *args) for i in xrange(n)])


def experiment(nxs, crrs):
    mean_times = np.empty([len(crrs), len(nxs)])
    for icrr, crr in enumerate(crrs):
        for ix, nx in enumerate(nxs):
            x1, x2 = generate_data(nx)
            mean_times[icrr, ix] = measure_mean_time(crr, x1, x2)
    return mean_times


if __name__ == '__main__:
    import matplotlib.pyplot as plt

    nxs1 = range(200, 200 + 1000, 200)
    nxs2 = range(200, 200 + 10000, 600)
    mean_times1 = experiment(nxs1, [crr_loop_scipy, crr_loop_numpy])
    mean_times2 = experiment(nxs2, [crr_numpy, crr_original])

    lebels1 = ['loop_sci', 'loop_np']
    labels2 = ['np', 'original']
    for label1, label2, mean_time1, mean_time2 in zip(lebels1, labels2, mean_times1, mean_times2):
        plt.plot(nxs1, mean_time1, label=label1)
        plt.plot(nxs2, mean_time2, label=label2)

    plt.xlabel('nx')
    plt.ylabel('time (sec)')
    plt.title('nt=1000')
    plt.xlim([0, nxs2[-1]])
    plt.ylim([0, 80])
    plt.legend()
    plt.show()
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした