※ 本記事の内容はTensorFlow 1.7.1で動作を確認しました.
tensorflow.contrib.framework.nest
では,ネストされたリストやタプルを1次元のリストに均したり,リストやタプルを任意のネストに変換する関数が提供されています.
特に,テンソルのネストが統一されなければいけない関数の呼び出しで効果を発揮するモジュールです.
本記事では,nest
の中でもnest.flatten
とnest.pack_sequence_as
に絞って紹介します.
それぞれの使い方の紹介と,例としてRNNの入力に使うstate
のパックを取り上げます.
#nest.flatten(x)
nest.flatten(x)
は任意のx
を1次元のリストに均した値を返します.
from tensorflow.contrib.framework import nest
x = ([1, 2, 3], [[4, 5], 6, [7, 8]], 9)
flat_x = nest.flatten(x)
print(flat_x) # -> [1, 2, 3, 4, 5, 6, 7, 8, 9]
#nest.pack_sequence_as(structure, flat_sequence)
nest.pack_sequence_as(structure, flat_sequence)
ではリストflat_sequence
を任意のネストstructure
に変換することができます.
structure
とflat_sequence
の各要素は同じである必要はありません.
from tensorflow.contrib.framework import nest
structure = ([1, 2, 3], [[4, 5], 6, [7, 8]], 9)
flat_sequence = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
sequence = nest.pack_sequence_as(structure, flat_sequence)
print(sequence) # -> (["a", "b", "c"], [["d", "e"], "f", ["g", "h"]], "i")
#使用例:RNNのstate
をパックする
上ではintのflatten/packを行いましたが,tensorflow.Tensor
に対してもこれらを適用できます.
本記事では,3層のLSTM-RNNを用意し入力引数のstate
をplaceholderにする例を紹介します.
nest
をplaceholderのベタ書きを回避できる上,feed_dict
の更新にも再利用できます.
##準備: RNNCellの定義
まず,3層のLSTM-RNNを定義します.
CELL_UNITS = 10
cells = [rnn_cell.LSTMCell(CELL_UNITS) for _ in range(3)]
rnn = rnn_cell.MultiRNNCell(cells)
##state
の用意
MultiRNNCellで前向き計算を実行するには,入力引数としてinputs
とstate
が必要です.
このstate
はrnn.state_size
と同じネストでなければいけません.
今,rnn.state_size
は
(LSTMStateTuple(c=10, h=10),
LSTMStateTuple(c=10, h=10),
LSTMStateTuple(c=10, h=10))
となっているので,state
をplaceholderで用意する場合は,
(LSTMStateTuple(c=PH1, h=PH2),
LSTMStateTuple(c=PH3, h=PH4),
LSTMStateTuple(c=PH5, h=PH6))
のようにする必要があります.
3層程度であればベタ書きでも構いませんが,より層が深くなる場合などを考えると現実的ではありません.
##nest
を利用したstate
のパック
ベタ書きを回避する方法として,nest
を使っていきます.
まず,(1) nest.flatten
でrnn.state_size
を均し,(2) placeholderを用意したあと,(3) 最後にnest.pack_sequence_as
でrnn.state_size
と同じネストにしてあげます.
# (1) state_sizeを均す
flat_state_size = nest.flatten(rnn.state_size)
print(flat_state_size) # -> [10, 10, 10, 10, 10, 10]
# (2) リストにplaceholderを追加していく
flat_state_ph = []
for size in flat_state_size:
ph = tf.placeholder(tf.float32, (BATCH_SIZE, size))
flat_state_ph.append(ph)
# (3) placeholderのリストをstate_sizeと同じネストにパックする
state_ph = nest.pack_sequence_as(rnn.state_size, flat_state_ph)
これで,rnn.state_size
と同じネストを持つplaceholderのstate
が定義できます.
##nest
を再利用してfeed_dict
を更新する
上のように,nest
を使ってplaceholderのstate
を定義すると,その過程で,均したstate_size
(上のflat_state_size
)やplaceholderのリスト(上のflat_state_ph
)を用意することになります.
このflat_state_size
やflat_state_ph
はfeed_dict
を更新する際にも使用できます.
# feed_dict の中身を用意,flat_state_sizeを再利用
flat_state_values = []
for size in flat_state_size:
state_value = np.random.normal((BATCH_SIZE, size)).astype(np.float32)
flat_state_values.append(state_value)
# feed_dictの更新,flat_state_phを再利用
feed_dict.update(zip(flat_state_ph, flat_state_values))
##コード全文
上の例を1つのコードにまとめました.ここでは適当に,入力に乱数を使用しています.
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework import nest
NUM_CELL_UNITS = 10
TIME_LENGTH = 15
BATCH_SIZE = 20
INPUT_UNITS = 5
rnn_cell = tf.nn.rnn_cell
# 前向き計算
def forward_rnn(rnn, inputs, state):
outputs = []
scope = tf.get_variable_scope()
with tf.variable_scope(scope):
for cur_input in inputs:
cur_output, state = rnn(cur_input, state)
outputs.append(cur_output)
return outputs, state
# 3層の LSTM-RNN を定義
cells = [rnn_cell.LSTMCell(NUM_CELL_UNITS) for _ in range(3)]
rnn = rnn_cell.MultiRNNCell(cells)
inputs_ph = [tf.placeholder(tf.float32, (BATCH_SIZE, INPUT_UNITS))
for _ in range(TIME_LENGTH)]
# state の placeholder を定義
# (1) state_sizeを均す
flat_state_size = nest.flatten(rnn.state_size)
print(flat_state_size) # -> [10, 10, 10, 10, 10, 10]
# (2) リストにplaceholderを追加していく
flat_state_ph = []
for size in flat_state_size:
ph = tf.placeholder(tf.float32, (BATCH_SIZE, size))
flat_state_ph.append(ph)
# (3) placeholderのリストをstate_sizeと同じネストにパックする
state_ph = nest.pack_sequence_as(rnn.state_size, flat_state_ph)
print("=== state ===")
print(state_ph)
# 前向き計算の実行
outputs, new_state = forward_rnn(rnn, inputs_ph, state_ph)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
feed_dict = {}
# inputs_ph に乱数を入れる
for i in range(TIME_LENGTH):
input_value = np.random.normal(size=(BATCH_SIZE, INPUT_UNITS)).astype(np.float32)
feed_dict[inputs_ph[i]] = input_value
# feed_dict の中身を用意,flat_state_sizeを再利用
flat_state_values = []
for size in flat_state_size:
state_value = np.random.normal(size=(BATCH_SIZE, size)).astype(np.float32)
flat_state_values.append(state_value)
# feed_dictの更新,flat_state_phを再利用
feed_dict.update(zip(flat_state_ph, flat_state_values))
outputs_value, new_state_value = sess.run([outputs, new_state],
feed_dict=feed_dict)
print("=== outputs ===")
print(outputs)
print("=== new_state === ")
print(new_state_value)