概要
tensorflowライブラリのmap_fn
という関数について紹介します.
map_fn
がどう動くのかを中心に書きます.
公式ドキュメントは「こちら」です.
内容
ある関数にテンソルの要素を一つ一つ与えたいときに使います.
mapは「写像」を意味していると思われます.
tf.map_fn(
fn,
elems,
dtype=None,
parallel_iterations=None,
back_prop=True,
swap_memory=False,
infer_shape=True,
name=None
)
elems
に入力テンソルを指定し,fn
に適用する関数を指定すると,elems
の要素が連続的にfn
に与えられます.
elems
はリスト,タプルにすることも可能ですが,最初の次元は一致させる必要があります.
1つのテンソル
簡単な例を次に示します.
import numpy as np
import tensorflow as tf
def func(x):
return x*x
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
elems = np.array([1,2,3,4,5,6])
op = tf.map_fn(func, elems)
result = sess.run(op)
print(type(result)) # <class 'numpy.ndarray'>
print(result.shape) # (6,)
print(result) # [ 1 4 9 16 25 36]
1つのテンソルに対して,関数が適用されていることが分かります.
しかし,これはfunc(elems)
のように,関数に引数を与えるときと変わりません.
map_fn
の挙動は次の例を見てもらえれば分かると思います.
import numpy as np
import tensorflow as tf
def func(x):
return 1
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
elems = np.array([1, 2, 3])
op = tf.map_fn(func, elems)
result = sess.run(op)
print(type(result)) # <class 'numpy.ndarray'>
print(result.shape) # (3,)
print(result) # [1 1 1]
1つのテンソルに対して,関数が適用されています.
テンソルは3つの要素を持っており,そのそれぞれに対して「1」が返り値として与えられ,結果としてarray([1, 1, 1])
がresult
に入ります.
複数のテンソル
2つのテンソルをタプルでelems
に渡したときの例です.
import numpy as np
import tensorflow as tf
def func(x):
return 5,6
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
elems = (np.array([1, 2, 3]), np.array([4,5,6]))
op = tf.map_fn(func, elems)
result = sess.run(op)
print(type(result)) # <class 'tuple'>
print(result) # (array([5, 5, 5]), array([6, 6, 6]))
2つのテンソルのそれぞれの要素に対し,順番に関数が適用され,1つ目のテンソルには「5」を,2つ目のテンソルには「6」を返しています.
返り値の形や型が,elems
と異なる場合は,dtype
で指定します.
import numpy as np
import tensorflow as tf
def func(x):
return x[0] * x[1]
def func2(x):
return x[0]+1, x[0]+2, x[1]+1, x[1]+2
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
elems = (np.array([1, 2, 3]), np.array([4,5,6]))
op = tf.map_fn(func, elems, dtype = tf.int32) # dtypeで指定
result = sess.run(op)
print(type(result)) # <class 'numpy.ndarray'>
print(result) # [ 4 10 18]
op2 = tf.map_fn(func2, elems, dtype = (tf.int32, tf.int32, tf.int32, tf.int32)) # dtypeで指定
result2 = sess.run(op2)
print(type(result2)) # <class 'tuple'>
print(result2) # (array([2, 3, 4]), array([3, 4, 5]), array([5, 6, 7]), array([6, 7, 8]))
おまけ
2次元テンソルの場合でも要素が一つずつ順に渡されていることが分かります.
import numpy as np
import tensorflow as tf
import sys
def func2(p, q):
return np.dot(p, q)
def func(x):
a, b = x
return func2(a,b)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
elems = (np.array([[1, 2, 3],[4,5,6]]), np.array([[10, 20, 30],[40,50,60]]))
op = tf.map_fn(func, elems, dtype = tf.int32)
result = sess.run(op)
print(type(result)) # <class 'numpy.ndarray'>
print(result) # [[ 10 40 90] [160 250 360]]