LoginSignup
0
1

More than 3 years have passed since last update.

tensorflow2 map_fn / scan の使い方

Last updated at Posted at 2020-12-20

この記事では tensorflow の map_fn / scan 関数の使用方法について解説します.

1. はじめに

チュートリアルを読めば使用方法は理解できますが, しばらく時間が経つとすぐに忘れてしまいそうなので備忘録的に記事としてまとめておきました.

2. 対象読者

さくっと map_fn / scan 関数の使い方を理解したい方.

3. 本編

3.1 map_fn

シーケンス入力の要素1つ1つに対して関数を作用させて出力シーケンスを作成します.

3.1.1 出力形状が入力形状と同じ(最も基本的)

  • 関数 or lambda 式で map する関数を指定して使います.
inputs = np.array([1,2,3,4])

# Define functions that process the elements of a sequence
def map_fn_1(x):
    return x + 1

outputs = tf.map_fn(fn=map_fn_1, elems=inputs)
print(outputs)
# >> tf.Tensor([2 3 4 5], shape=(4,), dtype=int64)

# We can also use lambda function
outputs = tf.map_fn(fn=lambda x: x + 1, elems=inputs)
print(outputs)
# >> tf.Tensor([2 3 4 5], shape=(4,), dtype=int64)

3.1.2 出力形状が入力形状と異なる場合

  • fn_output_signature で出力の型を指定する必要がある.
inputs = np.array([1,2,3,4])

def map_fn_2(x):
    return x, x + 1

# outputs = tf.map_fn(fn=map_fn_2, elems=inputs)
# >> ValueError: The two structures don't have the same nested structure.

# Specify output using fn_output_signature
outputs = tf.map_fn(fn=map_fn_2,
                    elems=inputs,
                    fn_output_signature=(tf.int64, tf.int64))
print(outputs)  
# >> (<tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 2, 3, 4])>,
#     <tf.Tensor: shape=(4,), dtype=int64, numpy=array([2, 3, 4, 5])>)

3.1.3 tuple 入力 / 異なる形状で出力

  • loop 方向の要素数が揃っていれば tuple で入力できる.
inputs1 = np.array([1,2,3,4,5])
inputs2 = np.zeros((5,3), dtype=np.int64)
print(inputs1.shape)    # (5,)
print(inputs2.shape)    # (5, 3)

def map_fn_3(args):
    _input1 = args[0]
    _input2 = args[1]
    return _input2 + _input1

outputs = tf.map_fn(fn=map_fn_3,
                    elems=(inputs1, inputs2),
                    fn_output_signature=(tf.int64))
print(outputs)
# >> tf.Tensor(
#    [[1 1 1]
#     [2 2 2]
#     [3 3 3]
#     [4 4 4]
#     [5 5 5]], shape=(5, 3), dtype=int64)

3.2 scan

scan 関数では累積的な演算を行う.

3.2.1 出力形状が入力形状と同じ(最も基本的)

inputs = np.array([1,2,3,4])

# Define functions that process the elements of a sequence
def scan_fn_1(accumulate, x):
    return accumulate + x

outputs = tf.scan(fn=scan_fn_1,
                  elems=inputs)
print(outputs)
# >> tf.Tensor([ 1  3  6 10], shape=(4,), dtype=int64)
# The accumulate value is initialized by fn output type.

# We can also use lambda function
outputs = tf.scan(fn=lambda a, x: a + x, elems=inputs)
print(outputs)
# >> tf.Tensor([ 1  3  6 10], shape=(4,), dtype=int64)

3.2.2 出力形状が入力形状と異なる場合

  • initializer で出力要素を初期化する必要がある
inputs = np.array([1,2,3,4])

def scan_fn_2(accumulate, x):
    return accumulate[0] + x, accumulate[1] + x + 1

# outputs = tf.scan(fn=scan_fn_2, elems=inputs)
# >> ValueError: The two structures don't have the same nested structure.

# Specify output using fn_output_signature
outputs = tf.scan(fn=scan_fn_2,
                  elems=inputs,
                  initializer=(np.array(0), np.array(0)))
print(outputs)  
# >> (<tf.Tensor: shape=(4,), dtype=int64, numpy=array([ 1,  3,  6, 10])>,
#     <tf.Tensor: shape=(4,), dtype=int64, numpy=array([ 2,  5,  9, 14])>)

3.2.3 tuple 入力 / 異なる形状で出力

  • loop 方向の要素数が揃っていれば tuple で入力できる.
inputs1 = np.array([1,2,3,4,5])
inputs2 = np.zeros((5,3), dtype=np.int64)
print(inputs1.shape)    # (5,)
print(inputs2.shape)    # (5, 3)

def scan_fn_3(accumulate, args):
    _input1 = args[0]
    _input2 = args[1]
    return accumulate + _input2 + _input1

outputs = tf.scan(fn=scan_fn_3,
                  elems=(inputs1, inputs2),
                  initializer=np.array([[0, 0, 0]]))
print(outputs)
# >> tf.Tensor(
#    [[[ 1  1  1]]
#     [[ 3  3  3]]
#     [[ 6  6  6]]
#     [[10 10 10]]
#     [[15 15 15]]], shape=(5, 1, 3), dtype=int64)

Gist Link

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1