Python
matplotlib
自動化
グラフ
マルチエージェントシステム

Keiです。前回投稿したPythonで人工社会シミュレーションの記事で予想外に沢山のいいねを頂けてニヤニヤしています。

前回の投稿ではひとまず人工社会でエージェントたちに好き勝手ゲームをさせるとこまでは行きましたが、おまけに載せているような綺麗なヒートマップの作り方は説明していませんでした。
そこで今回はPandasとMatplotとSeabornを使ってシミュレーション結果の可視化手法を説明していきたいと思いますが、ただグラフを書くだけなら他の記事でも散々紹介されているので、せっかくなので今回は「とにかく大量のcsvを読み込んでヒートマップやら散布図を書かなきゃいけない!こんなんExcelでやってたら日が暮れちまう!」という方向けに、大量のデータを一発でグラフにしてしまうTipsを紹介していきます。

やりたいことその1.大量のヒートマップを一度に描いて見比べる

例えば
output1.csv
output2.csv
output3.csv
.
.
.
output100.csv
のようにファイル毎に数字の連番が降ってあるcsvファイルを一気に全部ヒートマップにしたい!みたいな時ってたまーにないですか?
特にマルチエージェントシミュレーションとかやってると、計算結果が毎回実行するごとに微妙に変わる(使用する乱数系列の関係)ので、100回くらい違う乱数系列を使って計算してそれぞれのエピソードで計算結果がどのくらい違ってくるのかをチェックしたい場面が頻繁にあり、そんな時に以下のスクリプト達が役に立ってくれています。
なおcsvファイルの名前は上で書いたように"output(数字).csv" のようになっている場合を想定しています。

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

num_figure = 10    # 描画したいcsvファイルの数

fig = plt.figure()

for i in range(1, num_figure+1):
    df = pd.read_csv('output{}.csv'.format(i))    # csvファイル読み込み。
    df_pivot = df.pivot('col1','col2', 'value')   # col1, col2, valueの部分は読み込みたいデータに応じて変えてください

    ax = fig.add_subplot(2, 5, i)                 # これだと10枚のヒートマップを2行5列の配置で順に描画してくれます。
    sns.heatmap(df_pivot, cbar=False)             # cbar=Trueにすると同じカラーバーが10個出てくるので鬱陶しければこの例のようにFalseにしておく。
    ax.set_title("Output{}".format(i))            # それぞれのヒートマップにファイル名に応じたタイトルをつける

plt.tight_layout()                                # 見た目調整
fig.suptitle('ヒートマップ10枚見比べてみた')          # これはグラフ全体のタイトル
plt.show()
fig.savefig('heatmaps.png')

上ではfor文とformat文を使ってファイルごとに数字が異なっている部分を変えながらデータを読み込んでいってます。ちなみにこんな感じで10枚のヒートマップを並べて見比べることが出来ます↓


これは僕が以前研究中に人工社会シミュレーションの計算結果を10個のエピソードについて見比べた際のものですが、Episode7だけ他のエピソードより色の明るい範囲が広くなっています(=より強い社会ジレンマの下でも利他的な人間が生き残れている)。これは計算時に使用した乱数系列の違いで、たまたま利他的なエージェントにとって有利な初期条件で計算が始まったのがEpisode7だったということになります。

やりたいことその2.大量の時系列データを一つのグラフにまとめて描画

上と似たような感じで大量の時系列データも一つのグラフに押し込んで比較出来ます。
こんな感じで↓

import pandas as pd
import matplotlib.pyplot as plt

num_figure = 10    # 描画したいcsvファイルの数

ax = plt.subplot(1,1,1)

for i in range(1, num_figure+1):
    df = pd.read_csv('output{}.csv'.format(i))
    ax.scatter(df['time'], df['column'], label='output{}'.format(i))    #    
    ax.set_title('Time evolution of Fc', fontsize = 20)
    ax.set_xlabel('Time')
    ax.set_ylabel('Column')
    ax.legend()                      

plt.tight_layout()
plt.savefig('time_series_data.png')
plt.show()

こっちの方が簡単です。順番逆にすべきでしたねはい。
なんの差かというと、ヒートマップだと今回みたいにax.heatmapみたいな書き方が出来ない(seaborn使わないとヒートマップは描けない)からですね。
ちなみに上のスクリプト使って10個のcsv読み込んで一気に描画するとこんな感じのグラフが作れたりして↓

この例みたく10個の時系列データくらいならExcelでももちろん出来ますけど、これが100個とかなってくると流石にダルいじゃないですか。そういう時この手のスクリプト作っとけばpython script.pyって打てば終わりなんでありがたいですよね。