画像分類の入門であるMNISTに取り組んだとき一番最初に対処したエラーです。
参考書見ながら勉強しているとバージョンによるエラーがよくあるので、このあたりのアップデート情報には敏感にならないといけないなぁ
環境
google colaboratory
Python 3.7.13
sklearn 1.0.2
エラー内容
fetch_mldata
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata("mnist original", datahome=".")
X = mnist["data"].astype("float32")
y = mnist["data"].astype(int)
# エラーコード----------------------------------------------------------------------------
ImportError Traceback (most recent call last)
----> 1 from sklearn.datasets import fetch_mldata
2
3 mnist = fetch_mldata("mnist original", datahome=".")
4
5 mnist_X = mnist["data"].astype("float32")
ImportError: cannot import name 'fetch_mldata' from 'sklearn.datasets' (/usr/local/lib/python3.7/dist-packages/sklearn/datasets/__init__.py)
fetch_midata自体がもう古いっぽい
MNISTは有名なので他にいくらでも方法はありますが、とりあえずこの方法で読み込む場合の修正は以下のとおり
修正後コード
fetch_openml
from sklearn.datasets import fetch_openml
mnist_X, mnist_y = fetch_openml("mnist_784", version=1, return_X_y=True)
X = mnist_X.astype("float32")
y = mnist_y.astype(int)
「return_X_y=True」で訓練データとラベルデータを別々に分けて出力するように指定している。
これがないと一緒くたになって出力される。
中身はunit8型になっているようなので、取り出し後に型の変更をしておくこと。