どういうことをしたいか
行列から各行から異なる列の値を取り出したい(パターンA)
A = (a_{ij}) \in \mathbb{R}^{N\times M} \\
index = b_i \in \{1,\cdots,M\}^N \\
A[index] = a_{ib_i}
もしくは逆に各列から異なる行の値を取り出したい(パターンB)
A = (a_{ij}) \in \mathbb{R}^{N\times M} \\
index = b_j \in \{1,\cdots,N\}^M \\
A[index] = a_{b_j j}
具体例(パターンA)
A=\left(\begin{matrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9
\end{matrix}\right) \\
index=\left(\begin{matrix} 2 & 1 & 0 \end{matrix}\right) \\
A[index] = \left(\begin{matrix} 3 & 5 & 7 \end{matrix}\right)
TensorFlowでの実装
以下のコードで行列AからIDXで指定された値を抜き出せる。
パターンA用
# A : 行列
# IDX : index (dtype=tf.int32, int32でない場合はキャストが必要になることがある)
_IDX = tf.concat([tf.range(A.shape[0])[:,tf.newaxis], IDX[:,tf.newaxis]], axis=1)
subA = tf.gather_nd(A, _IDX)
パターンB用
# A : 行列
# IDX : index
_IDX = tf.concat(IDX[:,tf.newaxis], [tf.range(A.shape[1])[:,tf.newaxis]], axis=1)
subA = tf.gather_nd(A, _IDX)
簡単な解説
gather_ndは第一引数のテンソルから第二引数で指定された座標の値を抜き出して返す関数。
tf.rangeで生成したテンソルとindexをconcatすることで、抜き出したい座標を生成し、それをgather_ndに渡すことで各行(or 列)からindexで指定された値を抜き出すことを実現している。