記事を書いたきっかけ
下の2つの書籍で機械学習を勉強していたところ、両方で紹介されていたコードをそのまま書いても3次元グラフの描画ができませんでした。
(少し古いことを除けば、書籍としては滅茶苦茶おすすめです!)
from mpl_toolkits.mplot3d import Axes3D
xx = "sepal width (cm)"
yy = "sepal length (cm)"
zz = "petal length (cm)"
fig = plt.figure(figsize=(5, 5))
ax = Axes3D(fig)
ax.scatter(df0[xx], df0[yy], df0[zz], color="b")
ax.scatter(df1[xx], df1[yy], df1[zz], color="r")
ax.scatter(df2[xx], df2[yy], df2[zz], color="g")
ax.set_xlabel(xx)
ax.set_ylabel(yy)
ax.set_zlabel(zz)
plt.show()
<Figure size 500x500 with 0 Axes>
本来であれば3d散布図が描画されるはずですが、正常に出力されません
with 0 Axesということで、空の図が作成されているようです
描画されなかった原因
Matplotlib公式ドキュメント
Axes3D automatically adding itself to Figure is deprecated
New Axes3D objects previously added themselves to figures when they were created, unlike all other Axes classes, which lead to them being added twice if fig.add_subplot(111, projection='3d') was called.This behavior is now deprecated and will warn. The new keyword argument auto_add_to_figure controls the behavior and can be used to suppress the warning. The default value will change to False in Matplotlib 3.5, and any non-False value will be an error in Matplotlib 3.6.
In the future, Axes3D will need to be explicitly added to the figure
一言で言うと
「Axes3D
オブジェクトが自動的にFigure
オブジェクトに追加される動作は非推奨」
とのこと。
fig = plt.figure(figsize=(5, 5))
ax = Axes3D(fig)
上記部分にて、Axes3D
オブジェクトとFigure
オブジェクトは作成していますが、Matplotlibのバージョン3.4からは、Axes3D
オブジェクトを生成しても自動的にはFigure
オブジェクトには追加されなくなったようです。
解決方法
公式ドキュメントに書かれている通り、明示的に書けば描画されました。
# add to Figure
fig.add_axes(ax)
from mpl_toolkits.mplot3d import Axes3D
xx = "sepal width (cm)"
yy = "sepal length (cm)"
zz = "petal length (cm)"
fig = plt.figure(figsize=(5, 5))
ax = Axes3D(fig)
#ここを追加↓
fig.add_axes(ax)
ax.scatter(df0[xx], df0[yy], df0[zz], color="b")
ax.scatter(df1[xx], df1[yy], df1[zz], color="r")
ax.scatter(df2[xx], df2[yy], df2[zz], color="g")
ax.set_xlabel(xx)
ax.set_ylabel(yy)
ax.set_zlabel(zz)
plt.show()
他にも方法があるらしい
一般的にmatplotlibを使用して3Dグラフを作成するときは、Axes3D
メソッドではなく、Figure
クラスで定義されているadd_subplot
メソッドの方がよくつかわれているようです。
from mpl_toolkits.mplot3d import Axes3D
xx = "sepal width (cm)"
yy = "sepal length (cm)"
zz = "petal length (cm)"
fig = plt.figure(figsize=(5, 5))
#ここ↓
ax = fig.add_subplot(111, projection="3d")
ax.scatter(df0[xx], df0[yy], df0[zz], color="b")
ax.scatter(df1[xx], df1[yy], df1[zz], color="r")
ax.scatter(df2[xx], df2[yy], df2[zz], color="g")
ax.set_xlabel(xx)
ax.set_ylabel(yy)
ax.set_zlabel(zz)
plt.show()
Axes3D
メソッド(with明示的追加)とadd_subplot
メソッドどちらでも同様に描画できるみたいです。