LoginSignup
9
8

More than 5 years have passed since last update.

TensorFlowにネットワーク越しのデータを食べさせてみる

Posted at

これは日本情報クリエイト Engineers Advent Calendar 2016による6日目の記事になります。

今回はGoogleが提供する機械学習のライブラリであるTensorFlowの学習データをネットワークから取得する方法について記述します。

環境の構築は本家の説明を参照していただくとして早速本題に入っていきましょう。

※ちなみに自分はDocker使ってます。(楽なので。。。)

説明

今回は画像データの学習をしていきたいと思います。

だいたいWebでサービスを提供している会社であれば既に大量の画像データを持っていることも多いでしょうし、せっかくだからそれを学習データとして利用したいのでネットワーク越しの画像を使っていくという方法をとろうと思います。

準備

画像とラベルのリストはとりあえずCSVで準備しましょう。
こんな感じです。

image.csv
画像URL1,1
画像URL2,2
...

1カラム目に画像のURL、2カラム目に画像のラベル番号になってます

学習

では実際に学習に入っていきましょう!

まずは初期化から

import tensorflow as tf
import numpy as np
import httplib

## 定数
IMG_SIZE = 256 ## 学習させる画像の縦幅・横幅
IMG_LENGTH = IMG_SIZE * IMG_SIZE * 3 ## 学習させる画像データ長
LABEL_CNT = 10 ## ラベルの種類の数
IMG_DOMAIN = 'hogehoge.com' ## 画像が取得できるURLのドメイン名

## 学習に必要な変数の初期化
sess = tf.Session()
x = tf.placeholder(tf.float32, shape=[None, IMG_LENGTH])
W = tf.Variable(tf.zeros([IMG_LENGTH, LABEL_CNT]))
b = tf.Variable(tf.zeros([LABEL_CNT]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, shape=[None, LABEL_CNT])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

ここで『IMG_LENGTH』を画像の縦幅 x 横幅 x 3にしているのは理由があります。
今回学習で使用する画像はカラー画像なので1pxあたりのデータ量がRBGの3倍になります。
tf.placeholder()で学習時に後からデータを当てるための変数を確保していますが、第二引数の『shape=[...]』で指定したデータの形と一致しないと実行時にエラーになってしまいますので気をつけましょう。
※めちゃくちゃハマりました。。。

次にCSVファイルのデータをバッチ処理するための準備をします

## CSVファイルをワークキューとして設定
queue = tf.train.string_input_producer(['image.csv'])
reader = tf.TextLineReader()
key, val = reader.read(queue)
url, label = tf.decode_csv(val, [[''], [0]])

## バッチ処理の準備
batch_url, batch_label = tf.train.batch([url, label], batch_size=100)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

登録されたワークキューからバッチデータを取得します

try :
    while not coord.should_stop() :
        urls, labels = sess.run([batch_url, batch_label])

取得するデータはtf.train.batch()の『batch_size』で指定した数分のリストになっています

いよいよ画像をネットワーク越しに取得します!

        for url in urls :
            con = httplib.HTTPConnection(IMG_DOMAIN)
            con = con.request('GET', url)
            image = con.getresponse().read()

            ## 画像をTensorFlowで処理できるように変換
            tf.image.decode_jpeg(image, channels=3)
            image = tf.image.resize_image_with_crop_or_pad(image, IMG_SIZE, IMG_SIZE)
            image = tf.reshape(image, [-1])
            image = sess.run(image).astype(np.float32) / 255.0

・・・まぁ、大したことしてません。実際。
普通に画像をhttpで取ってくるだけです。
大事なのはどちらかというと画像を取得した後にTensorFlowで処理できるように変換するところです。

やってることは

  1. 取得する画像をまずはTensorFlowのデータ形式にデコード
  2. 画像のサイズを合わせます。
  3. このままだとデータの形が『縦 x 横 x RGB』の3次元配列になっているのでこれを1次元配列にして
  4. 最後にデータ型を浮動小数点に変換して終わりです。

この後はlabelも学習できる形式に変換して、実際に学習するだけです。
せっかくなのでコードも最後まで書いておきましょう。

        for label in labels :
            tmp = np.zeros(LABEL_CNT)
            tmp[label] = 1
            label = tmp
        sess.run(train_step, feed_dict={x: [image], y_: [label]})

    finally :
        correct_prediction = tf.equal(tf.argmax(self.y,1), tf.argmax(self.y_,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        sess.run(accuracy, feed_dict={x: [image], y_: [label]})
        coord.request_stop()
        coord.join(threads)

このコードではせっかくバッチで読み込んだデータを1行ずつしか学習してないので、まとめて学習させたい場合は適宜変えてもらえればいいかと。。。

まとめ

書いてみるとあっさり終わってる感がありますが、自分が機械学習に関してずぶの素人ということもありTensorFlow独自の癖(?)に終始振り回されてました。
実際画像の取得はなんてことなかったのですが、なかなかTensorFlowは食べてくれませんでした。
我が家の1歳児は水槽の中のグッピーからおもちゃまで何でもかんでも口に入れようとしてしまうので食べさせないようにするのに苦労しているというのに。。。この辺の貪欲さは幼児の圧勝ですねw

画像ってデータ量も大きいのにネットでは結構ローカルに学習データ持ってきてる情報が多くて実際皆さんどうされてるのか気になってます。

もしかしたらローカルに保持する理由もあるのかも。。
そのあたりに詳しい方がいたら是非教えてください!!

9
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
8