機械学習を簡単に行うための有用なライブラリの一つはscikit-learnである。
この文書では、scikit-learnの学習結果をpickleしないで済ます方法について述べる。
scikit-learnの特徴
- 各種識別器の学習・予測・評価方法のためのインターフェースがそろえてある設計。
- 各種アルゴリズムを試して、比較しやすい。
- ドキュメントが充実している。
前提
- python
- scikit-learn
- pickle
scikit-learn に欠けているもの
- scikit-learnで学習した結果を保持するための枠組みが不足している。
- そのため、sckit-learnで作った学習済みの識別器をpickleして、それをpickl.loads(pickle済みのファイル)して使ってしまうということをしてしまいやすい。
問題点
scikit-learn のサイトでも、pickleしたものを使うことの問題点が述べられている。
解決策
sklean-onnx のライブラリを用いて、学習済みの識別器をonnxファイルに変換する。そのonnxファイルを元にonnxruntimeを使って推論する。
このページの中で、以下の手づきの例があるので、それをたどればよい。
- Train a model.
- Convert into ONNX format
- Compute the prediction with ONNX Runtime
一連の動作の実施例
- sckit-learn で学習する。
- sklean-onnx を使ってonnxファイルに変換する。
- onnxファイルを元に推論を実行する。
sample.py
# Train a model.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = RandomForestClassifier()
clr.fit(X_train, y_train)
print(clr.predict(X_test))
# Convert into ONNX format
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
f.write(onx.SerializeToString())
# Compute the prediction with ONNX Runtime
import onnxruntime as rt
import numpy
sess = rt.InferenceSession("rf_iris.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
print(pred_onx)