LoginSignup
0
0

More than 3 years have passed since last update.

Tensorflowで指定したインデックスのテンソルを抜き出す

Posted at

どういうことをしたいか

行列から各行から異なる列の値を取り出したい(パターンA)

A = (a_{ij}) \in \mathbb{R}^{N\times M} \\
index = b_i \in \{1,\cdots,M\}^N \\
A[index] = a_{ib_i}

もしくは逆に各列から異なる行の値を取り出したい(パターンB)

A = (a_{ij}) \in \mathbb{R}^{N\times M} \\
index = b_j \in \{1,\cdots,N\}^M \\
A[index] = a_{b_j j}

具体例(パターンA)

A=\left(\begin{matrix}
  1 & 2 & 3 \\
  4 & 5 & 6 \\
  7 & 8 & 9
\end{matrix}\right) \\

index=\left(\begin{matrix} 2 & 1 & 0 \end{matrix}\right) \\

A[index] = \left(\begin{matrix} 3 & 5 & 7 \end{matrix}\right)

TensorFlowでの実装

以下のコードで行列AからIDXで指定された値を抜き出せる。

パターンA用

# A : 行列
# IDX : index (dtype=tf.int32, int32でない場合はキャストが必要になることがある)
_IDX = tf.concat([tf.range(A.shape[0])[:,tf.newaxis], IDX[:,tf.newaxis]], axis=1)
subA = tf.gather_nd(A, _IDX)

パターンB用

# A : 行列
# IDX : index
_IDX = tf.concat(IDX[:,tf.newaxis], [tf.range(A.shape[1])[:,tf.newaxis]], axis=1)
subA = tf.gather_nd(A, _IDX)

簡単な解説

gather_ndは第一引数のテンソルから第二引数で指定された座標の値を抜き出して返す関数。
tf.rangeで生成したテンソルとindexをconcatすることで、抜き出したい座標を生成し、それをgather_ndに渡すことで各行(or 列)からindexで指定された値を抜き出すことを実現している。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0