#tf.gather
と仲間たち
まず、gatheの処理をしてくれる仲間たちを紹介する。
-
tf.gather
指定したaxisに沿ってスライスしたものをgatherする。 -
tf.gather_nd
自由度が高く、任意の要素からgatherされる。 -
tf.batch_gather
batchに沿って処理される。indicesでスライスされる。
今回はバッチごとにgather処理をしてくれるtf.batch_gather
を使ってみた。
#tf.batch_gather
の使い方
##tf.batch_gather
の解説
tf.batch_gather(
params,
indices,
name=None
)
prams(Tensor): [A1, ..., AN-1, AN, B1, ..., BM]
indices(Tensor): [A1, ..., AN-1, C]
result(Tensor): [A1, ..., AN-1, C, B1, ..., BM]
- ANでスライスされ、A1〜AN-1の構造はそのまま維持される。
- CにANからgatherされる値が入る。
- Cの各値はANの大きさ以下でなければいけない。
詳しくはTensorFlowの公式ドキュメントを参考にしてもらいたい。
##今回の処理
点群屋さんなので、点群を例に説明します。
8点から成る3次元点群を用意した。ここから4点を選びgatherするような処理を行う。
Bはバッチ、Nは点群の点、Cはチャンネル(3次元点群なのでxyz)
##サンプルコード
batch_gather.py
import tensorflow as tf
#元の点群(2batch x 8点 x 3次元)
param = tf.constant([[[0.,0.,9.],[0.,1.,9.], [0.,2.,9.], [0.,3.,9.], [0.,4.,9.], [0.,5.,9.], [0.,6.,9.], [0.,7.,9.]],
[[1.,0.,9.],[1.,1.,9.], [1.,2.,9.], [1.,3.,9.], [1.,4.,9.], [1.,5.,9.], [1.,6.,9.], [1.,7.,9.]]]) # B x N x C(2 x 8 x 3)
print("param shape: ",param.shape)
#gather元を選ぶテンソル(2batch x 4点)
indices = tf.constant([[1,0,0,4],
[1,3,4,6]]) # B x N
print("indices shape: ",indices.shape)
#gatherする
result = tf.batch_gather(param, indices)
print("result shape: ",result.shape)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print("param\n",sess.run(param)) #入力の確認
print("indices\n",sess.run(indices)) #indiciesの確認
print("result\n",sess.run(result)) #gatherの結果
実行結果
param shape: (2, 8, 3)
indices shape: (2, 4)
result shape: (2, 4, 3)
param
[[[0. 0. 9.]
[0. 1. 9.]
[0. 2. 9.]
[0. 3. 9.]
[0. 4. 9.]
[0. 5. 9.]
[0. 6. 9.]
[0. 7. 9.]]
[[1. 0. 9.]
[1. 1. 9.]
[1. 2. 9.]
[1. 3. 9.]
[1. 4. 9.]
[1. 5. 9.]
[1. 6. 9.]
[1. 7. 9.]]]
indices
[[1 0 0 4]
[1 3 4 6]]
result
[[[0. 1. 9.]
[0. 0. 9.]
[0. 0. 9.]
[0. 4. 9.]]
[[1. 1. 9.]
[1. 3. 9.]
[1. 4. 9.]
[1. 6. 9.]]]