LoginSignup
7
4

More than 5 years have passed since last update.

TensorFlowのtutorial「tf.contrib.learn Quickstart」を試す

Posted at

はじめに

TensorFlowのtutorial tf.contrib.learn Quickstartを試したので共有します。

なお、2016/10/19現在、tutorial内のサンプルコードはエラーが出て動きませんでしたので、そちらについても共有します。

コード全体

コードの全体は下記となります。

sample.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

# Data sets
# 120sample
IRIS_TRAINING = "iris_training.csv"

# 30 sample
IRIS_TEST = "iris_test.csv"

# Load datasets.
# 下記だと動かない。
# load_csvがdeprecatedとのこと。
# training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING,target_dtype=np.int)
# test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST,
#                                                    target_dtype=np.int)

# load_csv_with_headerを用いる
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TRAINING,target_dtype=np.int, features_dtype=np.float32, target_column=-1)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TEST,
                                                   target_dtype=np.int, features_dtype=np.float32, target_column=-1)


# DNNの構築

# Specify that all features have real-value data
# four features(sepal w,sepal h,petal w,petal h)
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build 3 layer DNN with 10, 20, 10 units respectively.
# n_classes は 3つのirisの種類を示す
# checkpoint dataの保存先
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="/tmp/iris_model")



# Fit model.
# classifierの中にモデルの状態は保存されているので、
# 下記を2行書けば4000回の試行になる
classifier.fit(x=training_set.data,
               y=training_set.target,
               steps=2000)



# Evaluate accuracy.
accuracy_score = classifier.evaluate(x=test_set.data,
                                     y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))


# Classify two new flower samples.
# predictというメソッドを使います。
new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = classifier.predict(new_samples)
print('Predictions: {}'.format(str(y)))

データセットのダウンロード

本チュートリアルのデータセットはirisデータになります。
featureはsepal length, sepal width, petal length, petal widthの4つで、targetはIris setosa, Iris virginica,Iris versicolorの4つで、irisのどの種類かを判別するモデルを作るチュートリアルになります。
trainデータは120サンプル,testデータは30サンプルになります。

sample.py
# 120sample
IRIS_TRAINING = "iris_training.csv"

# 30 sample
IRIS_TEST = "iris_test.csv"

データセットの読み込み

ここでつまづきましたが、tutorialのままですと、下記のエラーが出ると思います。

AttributeError: 'module' object has no attribute 'load_csv'

調べてみると、load_csvはdeprecatedとなっており、load_csv_with_headerに変更する必要がありました。

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TRAINING,target_dtype=np.int, features_dtype=np.float32, target_column=-1)

では、target_dtypeでファイル中のtarget(花の種類)の列のフォーマット、features_dtypeでfeatureの各列のフォーマット、target_columnでtargetがどの列に記載されているか(今回は一番後ろなので-1)を指定しています。

sample.py
# Load datasets.

# 下記だと動かない
# training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING,target_dtype=np.int)
# test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST,
#                                                    target_dtype=np.int)

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TRAINING,target_dtype=np.int, features_dtype=np.float32, target_column=-1)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TEST,target_dtype=np.int, features_dtype=np.float32, target_column=-1)

DNNの構築

今回featuresはsepal length, sepal width, petal length, petal widthの4つなので4と指定しています。

また、hidden layerは3つでそれぞれ10,20,10
分類すべき花の種類は3つでn_classes = 3
checkpoint dataの保存先を/tmp/iris_modelと指定します。

sample.py
# Specify that all features have real-value data
# four features(sepal w,sepal h,petal w,petal h)
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build 3 layer DNN with 10, 20, 10 units respectively.
# n_classes は 3つのirisの種類を示す
# checkpoint dataの保存先
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="/tmp/iris_model")

モデルのフィット

classifierの中にモデルの状態は保存されているので、下記を2行繰り返すと、2000*2=4000回の試行が行われます。

sample.py
# Fit model.
# classifierの中にモデルの状態は保存されているので、
# 下記を2行書けば4000回の試行になる
classifier.fit(x=training_set.data,
               y=training_set.target,
               steps=2000)

accuracyの計算

evaluateというメソッドで、fitと同様に検証することができます。

sample.py
# Evaluate accuracy.
accuracy_score = classifier.evaluate(x=test_set.data,
                                     y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))

おおよそ、0.96程度の数字が帰ってくると思います。

Accuracy: 0.966667

新たなデータの分類

predictというメソッドを使うことで、新しいデータが来た時に分類することができます。

sample.py
# Classify two new flower samples.
# predictというメソッドを使います。
new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = classifier.predict(new_samples)
print('Predictions: {}'.format(str(y)))

それぞれのデータの予測がarrayで帰ってきます。

Predictions: [1 2]

さいごに

簡単にDNNが実装できますね。

ありがとうございました。

7
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
4