Help us understand the problem. What is going on with this article?

TensorFlowの基本的な関数をXORゲートを学習するニューラルネットワーク作成してまとめてみる

More than 1 year has passed since last update.

TensorFlowの関数についてよくわからなかったのでこの本を見てまとめてみた。

xorゲートをTensorFlowを使って実現する。
真理値表は以下、入力層の次元が2 出力層の次元が1となる

x1 x2 y
0 0 0
0 1 1
1 0 1
1 1 0

プログラムの流れ

①ライブラリのインポート

import numpy as np
import tensorflow as tf

②XORのデータ用意

X = np.array([[0,0],[0,1],[1,0],[1,1]])
Y = np.array([[0],[1],[1],[0]])

③入力と正解ラベルの入れ物用意

x = tf.placeholder(tf.float32, shape=[None,2])
t = tf.placeholder(tf.float32, shape=[None,1])

tf.placeholder()
データを格納する入れ物のような存在.
モデル定義の際には次元だけを決めておき、モデルの学習など実際にデータが必要になったタイミングで値を入れて実際の式を評価することを可能にする
shape=[None,2]は入力ベクトルの次元が2であることを表し、Noneとしているのはデータ数が可変でも対応できる入れ物となる
None部分 ← 要するにxorゲートの際は00,01,10,11の4つのデータであるが、実際にはデータ数がわからないことがあるのでNoneとなる

④モデルの定義(入力層 - 隠れ層)

x:入力  h:隠れ層の出力  W:重み  b:バイアス

h = Wx + b
W = tf.Variable(tf.truncated_normal([2,2]))
b = tf.Variable(tf.zeros([2]))
h = tf.nn.sigmoid(tf.matmul(x,W) + b)

tf.Variable()
変数を生成するのに必要。TensorFlowが持つ独自の型でデータを扱っていく
中身のtf.zeros()はNumpyにおけるnp.zeros()に相当する
tf.truncated_normal()は切断正規分布に従うデータを生成するメソッド。0で初期化すると正しく誤差が反映されなくなる恐れがあるため

⑤モデルの定義(隠れ層 - 出力層)

h:出力層への入力(隠れ層出力)  y:出力  V:重み  c:バイアス

y = Vh + c
V = tf.Variable(tf.truncated_normal([2,1]))
c = tf.Variable(tf.zeros([1]))
y = tf.nn.sigmoid(tf.matmul(h,V) + c)

説明は④と同じ

⑥誤差関数

cross_entropy = -tf.reduce_sum(t * tf.log(y) + (1-t) * tf.log(1-y))

今回は2値分類であるため、交差エントロピー関数を使用する

-tf.reduce_sum(t * tf.log(y) + (1-t) * tf.log(1-y))
交差エントロピー誤差関数の計算を数式通りに書くことができる
tf.reduce_sum()np.sum()に対応する

⑦確率的勾配降下法

train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cross_entropy)

確率的勾配降下法を適用している
GradientDescentOptimizer()の引数0.1は学習率

⑧ 学習後の結果の確認

correct_prediction = tf.equal(tf.to_float(tf.greater(y, 0.5)), t)

学習後の結果が正しいかどうか確認するための実装
y >= 0.5 でニューロンが発火する。それを正解ラベルのtと比較してTrue or Falseを返す

⑨セッション用意

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

TensorFlowでは必ずセッションというデータのやり取りの流れの中で計算が行われます。
ここではじめてモデルの定義で宣言した変数の・式の初期化が行われます。

⑩学習

for epoch in range(4000):
    sess.run(train_step, feed_dict={
        x:X,
        t:Y
    })

    if epoch % 1000 == 0:
        print('epoch:', epoch)

sess.run(train_step)これは勾配降下法による学習をすること
feed_dictplaceholderであるx,tに値を代入している
まさにplaceholderに値をfeed

⑪学習結果の確認(正解ラベルとの比較)

classified = correct_prediction.eval(session=sess, feed_dict={
    x:X,
    t:Y
})

eval()
ニューロンが発火する・しないを適切に分類できるようになっているかを確認するのに使う
要するにここでは、correct_predictionの値の確認に使用する

⑫学習結果の確認(出力確率)

prob = y.eval(session=sess, feed_dict={
    x:X,
    t:Y
})

各入力に対する出力確率を得ることができる
要するにyの値を確認できる

⑫表示

print('classified:')
print(classified)
print()
print('output probability:')
print(prob)

結果

出力
epoch: 0
epoch: 1000
epoch: 2000
epoch: 3000
classified:
[[ True]
 [ True]
 [ True]
 [ True]]

output probability:
[[ 0.00661706]
 [ 0.99109781]
 [ 0.99389231]
 [ 0.00563505]]

参考

詳解 ディープラーニング ~TensorFlow・Kerasによる時系列データ処理

taigamikami
大学生 自分の勉強・メモとしてQiitaに投稿しています。おかしいと思う部分は遠慮なくご指摘いただければと思います。 Ruby/Rails/Swift/iOS/Python
https://taigamikami.github.io/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした