LoginSignup
107
102

More than 5 years have passed since last update.

TensorFlowをscikit-learnライクに使えるskflow

Last updated at Posted at 2015-12-09

Pythonで機械学習を行う際によく使われるscikit-learnから、最近話題のTensorFlowを呼び出して使えてしまうという、目がくらむようなライブラリが登場しました。しかもgoogle謹製。それがskflowです。

TF Learn
※0.8.0から、TensorFlow本体に組み込まれました

scikit-learnには今までニューラルネットワーク系の実装がなかったのですが、これによりニューラルネットワークはもとより、ディープラーニングまで扱えるようになります。
また、TensorFlow側にとってもscikit-learn側に豊富にあるデータの前処理の機能(Preprocessingなど)と連携できるのは大きなメリットです。

そういう意味では、この二つをつなぐskflowの登場により機械学習処理の構築は非常にやりやすくなると思います。

インストール

TensorFlowの0.8.0からskflowは本体に取り込まれたため、個別のインストールは不要になりました。オプションとして、scikit-learnnumpypandasをインストールすると便利ですがなくても使えます。

使い方

GitHubからの引用になりますが、線形分類機は以下のようにかけます。

from tensorflow.contrib import skflow
from sklearn import datasets, metrics

iris = datasets.load_iris()
classifier = skflow.TensorFlowLinearClassifier(n_classes=3)
classifier.fit(iris.data, iris.target)
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
print("Accuracy: %f" % score)

skflow.TensorFlowLinearClassifierがTensorFlow側とのIFになっています。この後は、fitなどscikit-learnではおなじみのAPIで処理を進めていくことができます。

ニューラルネットワークについては、以下のようになります。TensorFlowDNNClassifierを利用し、隠れ層が10-20-10のニューラルネットワークを構築しています。

from tensorflow.contrib import skflow
from sklearn import datasets, metrics

iris = datasets.load_iris()
classifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
classifier.fit(iris.data, iris.target)
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
print("Accuracy: %f" % score)

より詳細にモデルを構築したい場合は、以下のように行います。skflow.opsで直接構築していく感じです。

from tensorflow.contrib import skflow
from sklearn import datasets, metrics

iris = datasets.load_iris()

def my_model(X, y):
    """This is DNN with 10, 20, 10 hidden layers, and dropout of 0.5 probability."""
    layers = skflow.ops.dnn(X, [10, 20, 10], dropout=0.5)
    return skflow.models.logistic_regression(layers, y)

classifier = skflow.TensorFlowEstimator(model_fn=my_model, n_classes=3)
classifier.fit(iris.data, iris.target)
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
print("Accuracy: %f" % score)

簡単な紹介となりましたが、逆に紹介することがあまりないくらいシンプルです。
TensorFlowに興味はあるけどまた新しい書き方覚えるのはちょっとな・・・という方は、ぜひ見ていただければと思います。

107
102
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
107
102