#TensorFlowの使い方
前回の動作の仕組み編で説明したように、TensorFlowは完全に計算をバックエンドに任せるために計算グラフを構築する。
今回はその計算グラフの構築の仕方と計算開始方法について説明する。
#tf.Session
TensorFlowのセッションオブジェクトを返す関数。これがないとTensorFlowで計算できない。
このオブジェクトは後でclose()することが望ましいとされているため、一般的にtf.Session()はwith文と一緒に宣言されることが多い。
また、sessという名前の変数に格納される場合が多い。
with tf.Session() as sess:
###各種処理###
#sess.run
tfで計算を開始できる関数。構築した計算グラフを元に実際に計算ができる。
sess.run(calc_graph, feed_dict={x:train_x,y:train_y})
calc_graphに構築した計算グラフを、feed_dictにはplace_holderへの入力を格納する。
返り値は計算グラフの計算結果であり、いわゆるニューラルネットワークであればそれの計算結果が返ってくる。
#計算グラフ系関数
実際にTensorFlowで計算グラフを構築する際に必要になる関数群。
##tf.constant
計算グラフで定数ノードを定義したい場合に使う関数。
const = tf.constant(1)
print(const)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(value))
上記の例は定数1のノードを計算グラフに定義している。
もちろんこれは定数を確保するという計算グラフを構築しただけであり、
print(const)しても「1」が出力されず、「Tensor("Const:0", shape=(), dtype=int32)」と出力される。
あくまでも計算グラフ上でどういうノードが存在しているのかを示しているだけである。
##tf.Variable
計算グラフにおける変数ノードを定義する関数。
ニューラルネットワークの重みなんかはこれで定義されているらしい。
value = tf.Variable(1)
print(value)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(value))
##tf.assign
計算グラフにおける代入処理を行うノード。どんなふうに使うか例を挙げてみる。
value = tf.Variable(0)
const = tf.constant(1)
add = value + const
assign = tf.assign(value, add)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(10):
print(sess.run(assign))
上記のコードは0に1を足す計算グラフを10回繰り返すという簡単な計算をコードである。
計算した結果をvalueに格納するために毎回assignで代入していることが分かる。
##tf.placeholder
何が入るか具体的な値は分からないが、型と配列サイズなどを指定して値を格納するノードだけ定義する関数。
計算グラフを関数に例えるなら引数みたいな役割をこなす。
placeholderで確保された箱はsess.runをする際にfeed_dictで受け取る必要がある。
value = tf.placeholder(tf.int32)
array = tf.placeholder(tf.int32,[5])
multiple = value * array
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(multiple,feed_dict = {value:3,array:[1,2,3,4,5]}))
上記のコードのようにただの定数のみのノードを用意することもできるし、大きさを指定した配列のノードも用意できる。
ちなみに以下のように型の後ろの配列をNoneにするとその次元数で受け取れるサイズの大きさが任意になる。
array = tf.placeholder(tf.int32,[None])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(array,feed_dict = {array:[1,2,3]}))
print(sess.run(array,feed_dict = {array:[1,2,3,4,5]}))
##その他関数メモ
基本的には上に書かれた関数を分かってればとりあえずTensorFlow入門は完了してそう。
後は自分で構築したい計算グラフの処理に応じて検索掛ければ多分色々出てくる。
英語だけど公式リファレンス見てもいい。
とりあえずなんとなく調べてて目についたその他の関数。
- tf.cond
- tf.while_loop
- tf.cast
- tf.shape
- tf.unstack
- tf.one_hot
- tf.reshape