2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlowのnestモジュール

Last updated at Posted at 2018-07-05

※ 本記事の内容はTensorFlow 1.7.1で動作を確認しました.

tensorflow.contrib.framework.nestでは,ネストされたリストやタプルを1次元のリストに均したり,リストやタプルを任意のネストに変換する関数が提供されています.
特に,テンソルのネストが統一されなければいけない関数の呼び出しで効果を発揮するモジュールです.

本記事では,nestの中でもnest.flattennest.pack_sequence_asに絞って紹介します.
それぞれの使い方の紹介と,例としてRNNの入力に使うstateのパックを取り上げます.

#nest.flatten(x)
nest.flatten(x)は任意のxを1次元のリストに均した値を返します.

example_flatten.py
from tensorflow.contrib.framework import nest

x = ([1, 2, 3], [[4, 5], 6, [7, 8]], 9)
flat_x = nest.flatten(x)

print(flat_x) # -> [1, 2, 3, 4, 5, 6, 7, 8, 9]

#nest.pack_sequence_as(structure, flat_sequence)
nest.pack_sequence_as(structure, flat_sequence)ではリストflat_sequenceを任意のネストstructureに変換することができます.
structureflat_sequenceの各要素は同じである必要はありません.

example_pack_sequence_as.py
from tensorflow.contrib.framework import nest

structure = ([1, 2, 3], [[4, 5], 6, [7, 8]], 9)
flat_sequence = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
sequence = nest.pack_sequence_as(structure, flat_sequence)

print(sequence) # -> (["a", "b", "c"], [["d", "e"], "f", ["g", "h"]], "i")

#使用例:RNNのstateをパックする
上ではintのflatten/packを行いましたが,tensorflow.Tensorに対してもこれらを適用できます.
本記事では,3層のLSTM-RNNを用意し入力引数のstateをplaceholderにする例を紹介します.
nestをplaceholderのベタ書きを回避できる上,feed_dictの更新にも再利用できます.

##準備: RNNCellの定義
まず,3層のLSTM-RNNを定義します.

CELL_UNITS = 10
cells = [rnn_cell.LSTMCell(CELL_UNITS) for _ in range(3)]
rnn = rnn_cell.MultiRNNCell(cells)

##stateの用意
MultiRNNCellで前向き計算を実行するには,入力引数としてinputsstateが必要です.
このstaternn.state_sizeと同じネストでなければいけません.
今,rnn.state_size

(LSTMStateTuple(c=10, h=10), 
 LSTMStateTuple(c=10, h=10),
 LSTMStateTuple(c=10, h=10))

となっているので,stateをplaceholderで用意する場合は,

(LSTMStateTuple(c=PH1, h=PH2),
 LSTMStateTuple(c=PH3, h=PH4),
 LSTMStateTuple(c=PH5, h=PH6))

のようにする必要があります.
3層程度であればベタ書きでも構いませんが,より層が深くなる場合などを考えると現実的ではありません.

##nestを利用したstateのパック
ベタ書きを回避する方法として,nestを使っていきます.
まず,(1) nest.flattenrnn.state_sizeを均し,(2) placeholderを用意したあと,(3) 最後にnest.pack_sequence_asrnn.state_sizeと同じネストにしてあげます.

# (1) state_sizeを均す
flat_state_size = nest.flatten(rnn.state_size)
print(flat_state_size) # -> [10, 10, 10, 10, 10, 10]

# (2) リストにplaceholderを追加していく
flat_state_ph = []
for size in flat_state_size:
    ph = tf.placeholder(tf.float32, (BATCH_SIZE, size))
    flat_state_ph.append(ph)

# (3) placeholderのリストをstate_sizeと同じネストにパックする
state_ph = nest.pack_sequence_as(rnn.state_size, flat_state_ph)

これで,rnn.state_sizeと同じネストを持つplaceholderのstateが定義できます.

##nestを再利用してfeed_dictを更新する
上のように,nestを使ってplaceholderのstateを定義すると,その過程で,均したstate_size(上のflat_state_size)やplaceholderのリスト(上のflat_state_ph)を用意することになります.
このflat_state_sizeflat_state_phfeed_dictを更新する際にも使用できます.

# feed_dict の中身を用意,flat_state_sizeを再利用
flat_state_values = []
for size in flat_state_size:
    state_value = np.random.normal((BATCH_SIZE, size)).astype(np.float32)
    flat_state_values.append(state_value)

# feed_dictの更新,flat_state_phを再利用
feed_dict.update(zip(flat_state_ph, flat_state_values))

##コード全文
上の例を1つのコードにまとめました.ここでは適当に,入力に乱数を使用しています.

# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework import nest

NUM_CELL_UNITS = 10
TIME_LENGTH = 15
BATCH_SIZE = 20
INPUT_UNITS = 5

rnn_cell = tf.nn.rnn_cell

# 前向き計算
def forward_rnn(rnn, inputs, state):
    outputs = []
    scope = tf.get_variable_scope()
    with tf.variable_scope(scope):
        for cur_input in inputs:
            cur_output, state = rnn(cur_input, state) 
            outputs.append(cur_output)
    return outputs, state

# 3層の LSTM-RNN を定義
cells = [rnn_cell.LSTMCell(NUM_CELL_UNITS) for _ in range(3)]
rnn = rnn_cell.MultiRNNCell(cells)

inputs_ph = [tf.placeholder(tf.float32, (BATCH_SIZE, INPUT_UNITS))
             for _ in range(TIME_LENGTH)]

# state の placeholder を定義
# (1) state_sizeを均す
flat_state_size = nest.flatten(rnn.state_size)
print(flat_state_size) # -> [10, 10, 10, 10, 10, 10]
# (2) リストにplaceholderを追加していく
flat_state_ph = []
for size in flat_state_size:
    ph = tf.placeholder(tf.float32, (BATCH_SIZE, size))
    flat_state_ph.append(ph)
# (3) placeholderのリストをstate_sizeと同じネストにパックする
state_ph = nest.pack_sequence_as(rnn.state_size, flat_state_ph)
print("=== state ===")
print(state_ph)

# 前向き計算の実行
outputs, new_state = forward_rnn(rnn, inputs_ph, state_ph)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

feed_dict = {}

# inputs_ph に乱数を入れる
for i in range(TIME_LENGTH):
    input_value = np.random.normal(size=(BATCH_SIZE, INPUT_UNITS)).astype(np.float32)
    feed_dict[inputs_ph[i]] = input_value

# feed_dict の中身を用意,flat_state_sizeを再利用
flat_state_values = []
for size in flat_state_size:
    state_value = np.random.normal(size=(BATCH_SIZE, size)).astype(np.float32)
    flat_state_values.append(state_value)
# feed_dictの更新,flat_state_phを再利用
feed_dict.update(zip(flat_state_ph, flat_state_values))

outputs_value, new_state_value = sess.run([outputs, new_state], 
                                          feed_dict=feed_dict)

print("=== outputs ===")
print(outputs)

print("=== new_state === ")
print(new_state_value)
2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?