概要
静かな秋の夜に突発的に「そういえばsklearnってどうやってirisのデータ読み込んでいるのだろう」
と気になったので調べてみた。
(2022/11/17現在)
内容
早速load_irisのソースコードっぽいのをみてみる。(Github)
load_iris()
data_file_name = "iris.csv"
data, target, target_names, fdescr = load_csv_data(
data_file_name=data_file_name, descr_file_name="iris.rst"
)
feature_names = [
"sepal length (cm)",
"sepal width (cm)",
"petal length (cm)",
"petal width (cm)",
]
frame = None
target_columns = [
"target",
]
if as_frame:
frame, data, target = _convert_data_dataframe(
"load_iris", data, target, feature_names, target_columns
)
(中略)
return Bunch(
data=data,
target=target,
frame=frame,
target_names=target_names,
DESCR=fdescr,
feature_names=feature_names,
filename=data_file_name,
data_module=DATA_MODULE,
)
data_file_name
を指定(iris.csv
)して、load_csv_data
で読み込んだ後、
それらをBunch
で束ねて返しているのか。
load_csv_data
には他にiris.rst
も指定していて、返り値は、data
, target
, target_names
, fdescr
に格納されて、feature_names
はここで指定、DataFrameも必要があれば追加される感じ。
なお、DATA_MODULE = "sklearn.datasets.data"
load_csv_data
def load_csv_data(
data_file_name,
*,
data_module=DATA_MODULE,
descr_file_name=None,
descr_module=DESCR_MODULE,
):
(中略)
with resources.open_text(data_module, data_file_name) as csv_file:
data_file = csv.reader(csv_file)
temp = next(data_file)
n_samples = int(temp[0])
n_features = int(temp[1])
target_names = np.array(temp[2:])
data = np.empty((n_samples, n_features))
target = np.empty((n_samples,), dtype=int)
for i, ir in enumerate(data_file):
data[i] = np.asarray(ir[:-1], dtype=np.float64)
target[i] = np.asarray(ir[-1], dtype=int)
if descr_file_name is None:
return data, target, target_names
else:
assert descr_module is not None
descr = load_descr(descr_module=descr_module, descr_file_name=descr_file_name)
return data, target, target_names, descr
(data, target, target_names, fdescr = load_csv_data(data_file_name=data_file_name, descr_file_name="iris.rst"
)
descr
はやはりdescr_file_name
から読み込まれている模様
csv.reader
でのnext
はheaderの処理に用いられ、
ヘッダーの1列目がn_samples
, 2列目がn_features
, 3列目以降がtarget_names
にしてあるということ。(そこにfeature_namesはないんか)
resources.open_text
については後述。先にiris.csv
を見る。
iris.csv
150,4,setosa,versicolor,virginica
5.1,3.5,1.4,0.2,0
4.9,3.0,1.4,0.2,0
4.7,3.2,1.3,0.2,0
(中略)
5.9,3.0,5.1,1.8,2
iris.csvの本体は、datasets.dataにいた模様。(Github)
確かに1行目はヘッダーになっている。
resources.open_text
resources
の出どころは、17行目
from importlib import resources
importlib.resources.read_text(package, resource, encoding='utf-8', errors='strict')
Read and return the contents of resource within package as a str. By default, the contents are read as strict UTF-8. (link)
(with resources.open_text(data_module, data_file_name) as csv_file
)
(DATA_MODULE = "sklearn.datasets.data"
)
上で見た通りのiris.csv
をロードできるようになっている。
感想
iris.csv
の場所がわかってよかった。あと、iris.csv
自体にはfeature_names
が含まれていないことが意外だった。
ちなみに、ヘッダーがあるので、そのままread_csvをしないように注意が必要だなと思った。(そのままは使わないだろうけれど)
dtypeの指定は、関数内で行われているのも意外だった。
あと、csvの読み込みが1行1行行われていることも意外だった。
面白かった。
蛇足
datasets/tests/test_base.pyにテストコードあり
test_loader()
@pytest.mark.filterwarnings("ignore:Function load_boston is deprecated")
@pytest.mark.parametrize(
"loader_func, data_shape, target_shape, n_target, has_descr, filenames",
[
(load_breast_cancer, (569, 30), (569,), 2, True, ["filename"]),
(load_wine, (178, 13), (178,), 3, True, []),
(load_iris, (150, 4), (150,), 3, True, ["filename"]),
(中略)
(load_boston, (506, 13), (506,), None, True, ["filename"]),
],
)
def test_loader(loader_func, data_shape, target_shape, n_target, has_descr, filenames):
bunch = loader_func()
assert isinstance(bunch, Bunch)
assert bunch.data.shape == data_shape
assert bunch.target.shape == target_shape
if hasattr(bunch, "feature_names"):
assert len(bunch.feature_names) == data_shape[1]
if n_target is not None:
assert len(bunch.target_names) == n_target
if has_descr:
assert bunch.DESCR
if filenames:
assert "data_module" in bunch
assert all(
[
f in bunch and resources.is_resource(bunch["data_module"], bunch[f])
for f in filenames
]
)
test_toy_dataset_frame_dtype()
@pytest.mark.parametrize(
"loader_func, data_dtype, target_dtype",
[
(load_breast_cancer, np.float64, int),
(load_diabetes, np.float64, np.float64),
(load_digits, np.float64, int),
(load_iris, np.float64, int),
(load_linnerud, np.float64, np.float64),
(load_wine, np.float64, int),
],
)
def test_toy_dataset_frame_dtype(loader_func, data_dtype, target_dtype):
default_result = loader_func()
check_as_frame(
default_result,
loader_func,
expected_data_dtype=data_dtype,
expected_target_dtype=target_dtype,
)
test_bunch_dir()
def test_bunch_dir():
# check that dir (important for autocomplete) shows attributes
data = load_iris()
assert "data" in dir(data)
出典
- https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/datasets/_base.py
- https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/datasets/data/iris.csv
- https://docs.python.org/ja/3/library/importlib.resources.html#module-importlib.resources
- https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/datasets/tests/test_common.py