LoginSignup
0
2

More than 3 years have passed since last update.

matplotlibで予測結果を3D表示

Last updated at Posted at 2021-01-22

はじめに

深層学習で予測した結果のうち誤差が大きいものと,誤差が小さいものに特徴はあるのか可視化して確認したかったので,3Dグラフにプロットしてみました.
* 予測に失敗:赤
* 予測に成功:青
で示しています.

使ったライブラリ

  • matplotlib
  • pandas
  • Pillow

3D表示

Matplotlib公式が色分けの散布図の表示の仕方を載せていたので参考にしました.
mplot3d tutorial — Matplotlib 2.0.2 documentation

使うデータ

x, y, z, TFとラベル付けしました.

描画

plot3d.py
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from io import BytesIO
from PIL import Image

import_path = './test.csv'


def main():
    fig = plt.figure(figsize=(10.0, 9.0))
    ax = fig.add_subplot(111, projection='3d')

    data = pd.read_csv(import_path, header=0)
    print(data)
    for index, row in data.iterrows():
        print('\rindex = {}/{}, TorF = {}  '.format(index,len(data), str(row['TF'])),end="")
        color = 'r'
        marker = 'x'
        if str(row['TF']) == 'True':
            color = 'b'
            marker = 'o'
        ax.scatter(row['x'], row['y'], row['z'], c = color, marker = marker)

    print('\n---Finished Loading---')
    ax.set_xlabel("x", fontsize=20)
    ax.set_ylabel("y", fontsize=20)
    ax.set_zlabel("z", fontsize=20)
    ax.view_init(30, 0)
    plt.savefig('./output2D-1.png')
    ax.view_init(30, 90)
    plt.savefig('./output2D-2.png')
    ax.view_init(30, 180)
    plt.savefig('./output2D-3.png')


if __name__ == '__main__':
    main()

描画した結果

output2D-1.png
output2D-2.png
output2D-3.png

3D表示を回転させる

GIFアニメにすることで回転させてみました.
こちらの記事を参考にさせていただきました.
3D 散布図を回転 GIF アニメーションにする - Qiita

plot3d.py
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from io import BytesIO
from PIL import Image
import time
import datetime

import_path = './test.csv'


def  render_frame(ax, angle):
    print('\rangle = {}'.format(angle),end="")
    ax.view_init(30, angle)
    buf = BytesIO()
    plt.savefig(buf, bbox_inches='tight', pad_inches=0.0)
    return Image.open(buf)


def main():
    start = time.time()
    fig = plt.figure(figsize=(10.0, 9.0))
    ax = fig.add_subplot(111, projection='3d')

    data = pd.read_csv(import_path, header=0)
    for index, row in data.iterrows():
        print('\rindex = {}/{}, TorF = {}  '.format(index,len(data), str(row['TF'])),end="")
        color = 'r'
        marker = 'x'
        if str(row['TF']) == 'True':
            color = 'b'
            marker = 'o'
        ax.scatter(row['x'], row['y'], row['z'], c = color, marker = marker)

    ax.set_xlabel("x", fontsize=20)
    ax.set_ylabel("y", fontsize=20)
    ax.set_zlabel("z", fontsize=20)

    print('\n---Finished Loading---\n')
    images = [render_frame(ax, angle) for angle in range(50)]
    images[0].save('output3D.gif', save_all=True, append_images=images[1:], duration=100, loop=0)

    # 経過時間の集計
    process_time = time.time() - start
    td = datetime.timedelta(seconds=process_time)
    print('PROCESS TIME = {}'.format(td))


if __name__ == '__main__':
    main()

GIFアニメ

output3D.gif

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2