LoginSignup
0
0

More than 1 year has passed since last update.

load_iris(sklearn)

Posted at

概要

静かな秋の夜に突発的に「そういえばsklearnってどうやってirisのデータ読み込んでいるのだろう」
と気になったので調べてみた。
(2022/11/17現在)

内容

早速load_irisのソースコードっぽいのをみてみる。(Github)

load_iris()

datasets/_base.py
    data_file_name = "iris.csv"
    data, target, target_names, fdescr = load_csv_data(
        data_file_name=data_file_name, descr_file_name="iris.rst"
    )

    feature_names = [
        "sepal length (cm)",
        "sepal width (cm)",
        "petal length (cm)",
        "petal width (cm)",
    ]

    frame = None
    target_columns = [
        "target",
    ]
    if as_frame:
        frame, data, target = _convert_data_dataframe(
            "load_iris", data, target, feature_names, target_columns
        )
    (中略)

    return Bunch(
        data=data,
        target=target,
        frame=frame,
        target_names=target_names,
        DESCR=fdescr,
        feature_names=feature_names,
        filename=data_file_name,
        data_module=DATA_MODULE,
    )

data_file_nameを指定(iris.csv)して、load_csv_dataで読み込んだ後、
それらをBunchで束ねて返しているのか。

load_csv_dataには他にiris.rstも指定していて、返り値は、data, target, target_names, fdescrに格納されて、feature_namesはここで指定、DataFrameも必要があれば追加される感じ。
なお、DATA_MODULE = "sklearn.datasets.data"

load_csv_data

datasets/_base.py
def load_csv_data(
    data_file_name,
    *,
    data_module=DATA_MODULE,
    descr_file_name=None,
    descr_module=DESCR_MODULE,
):
    (中略)
    with resources.open_text(data_module, data_file_name) as csv_file:
        data_file = csv.reader(csv_file)
        temp = next(data_file)
        n_samples = int(temp[0])
        n_features = int(temp[1])
        target_names = np.array(temp[2:])
        data = np.empty((n_samples, n_features))
        target = np.empty((n_samples,), dtype=int)

        for i, ir in enumerate(data_file):
            data[i] = np.asarray(ir[:-1], dtype=np.float64)
            target[i] = np.asarray(ir[-1], dtype=int)

    if descr_file_name is None:
        return data, target, target_names
    else:
        assert descr_module is not None
        descr = load_descr(descr_module=descr_module, descr_file_name=descr_file_name)
        return data, target, target_names, descr

(data, target, target_names, fdescr = load_csv_data(data_file_name=data_file_name, descr_file_name="iris.rst")
descrはやはりdescr_file_nameから読み込まれている模様

csv.readerでのnextはheaderの処理に用いられ、
ヘッダーの1列目がn_samples, 2列目がn_features, 3列目以降がtarget_namesにしてあるということ。(そこにfeature_namesはないんか)

resources.open_textについては後述。先にiris.csvを見る。

iris.csv

datasets/data/iris.csv
150,4,setosa,versicolor,virginica
5.1,3.5,1.4,0.2,0
4.9,3.0,1.4,0.2,0
4.7,3.2,1.3,0.2,0
(中略)
5.9,3.0,5.1,1.8,2

iris.csvの本体は、datasets.dataにいた模様。(Github)
確かに1行目はヘッダーになっている。

resources.open_text

resourcesの出どころは、17行目

datasets/_base.py
from importlib import resources

importlib.resources.read_text(package, resource, encoding='utf-8', errors='strict')

Read and return the contents of resource within package as a str. By default, the contents are read as strict UTF-8. (link)

(with resources.open_text(data_module, data_file_name) as csv_file)
(DATA_MODULE = "sklearn.datasets.data")
上で見た通りのiris.csvをロードできるようになっている。

感想

iris.csvの場所がわかってよかった。あと、iris.csv自体にはfeature_namesが含まれていないことが意外だった。
ちなみに、ヘッダーがあるので、そのままread_csvをしないように注意が必要だなと思った。(そのままは使わないだろうけれど)

dtypeの指定は、関数内で行われているのも意外だった。

あと、csvの読み込みが1行1行行われていることも意外だった。

面白かった。

蛇足

datasets/tests/test_base.pyにテストコードあり

test_loader()

datasets/tests/test_base.py
@pytest.mark.filterwarnings("ignore:Function load_boston is deprecated")
@pytest.mark.parametrize(
    "loader_func, data_shape, target_shape, n_target, has_descr, filenames",
    [
        (load_breast_cancer, (569, 30), (569,), 2, True, ["filename"]),
        (load_wine, (178, 13), (178,), 3, True, []),
        (load_iris, (150, 4), (150,), 3, True, ["filename"]),
        (中略)
        (load_boston, (506, 13), (506,), None, True, ["filename"]),
    ],
)
def test_loader(loader_func, data_shape, target_shape, n_target, has_descr, filenames):
    bunch = loader_func()

    assert isinstance(bunch, Bunch)
    assert bunch.data.shape == data_shape
    assert bunch.target.shape == target_shape
    if hasattr(bunch, "feature_names"):
        assert len(bunch.feature_names) == data_shape[1]
    if n_target is not None:
        assert len(bunch.target_names) == n_target
    if has_descr:
        assert bunch.DESCR
    if filenames:
        assert "data_module" in bunch
        assert all(
            [
                f in bunch and resources.is_resource(bunch["data_module"], bunch[f])
                for f in filenames
            ]
        )

test_toy_dataset_frame_dtype()

@pytest.mark.parametrize(
    "loader_func, data_dtype, target_dtype",
    [
        (load_breast_cancer, np.float64, int),
        (load_diabetes, np.float64, np.float64),
        (load_digits, np.float64, int),
        (load_iris, np.float64, int),
        (load_linnerud, np.float64, np.float64),
        (load_wine, np.float64, int),
    ],
)
def test_toy_dataset_frame_dtype(loader_func, data_dtype, target_dtype):
    default_result = loader_func()
    check_as_frame(
        default_result,
        loader_func,
        expected_data_dtype=data_dtype,
        expected_target_dtype=target_dtype,
    )

test_bunch_dir()

def test_bunch_dir():
    # check that dir (important for autocomplete) shows attributes
    data = load_iris()
    assert "data" in dir(data)

出典

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0