1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

tf.map_fnの使い方

Posted at

概要

tensorflowライブラリのmap_fnという関数について紹介します.
map_fnがどう動くのかを中心に書きます.
公式ドキュメントは「こちら」です.

内容

ある関数にテンソルの要素を一つ一つ与えたいときに使います.
mapは「写像」を意味していると思われます.

.py
tf.map_fn(
    fn,
    elems,
    dtype=None,
    parallel_iterations=None,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)

elemsに入力テンソルを指定し,fnに適用する関数を指定すると,elemsの要素が連続的にfnに与えられます.
elemsはリスト,タプルにすることも可能ですが,最初の次元は一致させる必要があります.

1つのテンソル

簡単な例を次に示します.

test.py
import numpy as np
import tensorflow as  tf

def func(x):
    return x*x

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

elems = np.array([1,2,3,4,5,6])
op = tf.map_fn(func, elems)
result = sess.run(op)
print(type(result)) # <class 'numpy.ndarray'>
print(result.shape) # (6,)
print(result) # [ 1  4  9 16 25 36]

1つのテンソルに対して,関数が適用されていることが分かります.
しかし,これはfunc(elems)のように,関数に引数を与えるときと変わりません.

map_fnの挙動は次の例を見てもらえれば分かると思います.

test2.py
import numpy as np
import tensorflow as  tf

def func(x):
    return 1

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

elems = np.array([1, 2, 3])
op = tf.map_fn(func, elems)
result = sess.run(op)
print(type(result)) # <class 'numpy.ndarray'>
print(result.shape) # (3,)
print(result) # [1 1 1]

1つのテンソルに対して,関数が適用されています.
テンソルは3つの要素を持っており,そのそれぞれに対して「1」が返り値として与えられ,結果としてarray([1, 1, 1])resultに入ります.

複数のテンソル

2つのテンソルをタプルでelemsに渡したときの例です.

test3.py
import numpy as np
import tensorflow as  tf

def func(x):
    return 5,6

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

elems = (np.array([1, 2, 3]), np.array([4,5,6]))
op = tf.map_fn(func, elems)
result = sess.run(op)
print(type(result)) # <class 'tuple'>
print(result) # (array([5, 5, 5]), array([6, 6, 6]))

2つのテンソルのそれぞれの要素に対し,順番に関数が適用され,1つ目のテンソルには「5」を,2つ目のテンソルには「6」を返しています.

返り値の形や型が,elemsと異なる場合は,dtypeで指定します.

test4.py
import numpy as np
import tensorflow as  tf

def func(x):
    return x[0] * x[1]

def func2(x):
    return x[0]+1, x[0]+2, x[1]+1, x[1]+2

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

elems = (np.array([1, 2, 3]), np.array([4,5,6]))

op = tf.map_fn(func, elems, dtype = tf.int32) # dtypeで指定
result = sess.run(op)
print(type(result)) # <class 'numpy.ndarray'>
print(result) # [ 4 10 18]

op2 = tf.map_fn(func2, elems, dtype = (tf.int32, tf.int32, tf.int32, tf.int32)) # dtypeで指定
result2 = sess.run(op2)
print(type(result2)) # <class 'tuple'>
print(result2) # (array([2, 3, 4]), array([3, 4, 5]), array([5, 6, 7]), array([6, 7, 8]))

おまけ

2次元テンソルの場合でも要素が一つずつ順に渡されていることが分かります.

test5.py
import numpy as np
import tensorflow as  tf
import sys

def func2(p, q):
    return np.dot(p, q)

def func(x):
    a, b = x
    return func2(a,b)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

elems = (np.array([[1, 2, 3],[4,5,6]]), np.array([[10, 20, 30],[40,50,60]]))
op = tf.map_fn(func, elems, dtype = tf.int32)
result = sess.run(op)
print(type(result)) # <class 'numpy.ndarray'>
print(result) # [[ 10  40  90] [160 250 360]]
1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?