Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
0
Help us understand the problem. What is going on with this article?
@exy81

Tensorflowで指定したインデックスのテンソルを抜き出す

More than 1 year has passed since last update.

どういうことをしたいか

行列から各行から異なる列の値を取り出したい(パターンA)

A = (a_{ij}) \in \mathbb{R}^{N\times M} \\
index = b_i \in \{1,\cdots,M\}^N \\
A[index] = a_{ib_i}

もしくは逆に各列から異なる行の値を取り出したい(パターンB)

A = (a_{ij}) \in \mathbb{R}^{N\times M} \\
index = b_j \in \{1,\cdots,N\}^M \\
A[index] = a_{b_j j}

具体例(パターンA)

A=\left(\begin{matrix}
  1 & 2 & 3 \\
  4 & 5 & 6 \\
  7 & 8 & 9
\end{matrix}\right) \\

index=\left(\begin{matrix} 2 & 1 & 0 \end{matrix}\right) \\

A[index] = \left(\begin{matrix} 3 & 5 & 7 \end{matrix}\right)

TensorFlowでの実装

以下のコードで行列AからIDXで指定された値を抜き出せる。

パターンA用

# A : 行列
# IDX : index (dtype=tf.int32, int32でない場合はキャストが必要になることがある)
_IDX = tf.concat([tf.range(A.shape[0])[:,tf.newaxis], IDX[:,tf.newaxis]], axis=1)
subA = tf.gather_nd(A, _IDX)

パターンB用

# A : 行列
# IDX : index
_IDX = tf.concat(IDX[:,tf.newaxis], [tf.range(A.shape[1])[:,tf.newaxis]], axis=1)
subA = tf.gather_nd(A, _IDX)

簡単な解説

gather_ndは第一引数のテンソルから第二引数で指定された座標の値を抜き出して返す関数。
tf.rangeで生成したテンソルとindexをconcatすることで、抜き出したい座標を生成し、それをgather_ndに渡すことで各行(or 列)からindexで指定された値を抜き出すことを実現している。

0
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
exy81
しばらくZennとQiitaに同じ記事を投稿します。

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
0
Help us understand the problem. What is going on with this article?