tf.batch_gatherの使い方メモ


tf.gatherと仲間たち

まず、gatheの処理をしてくれる仲間たちを紹介する。

- tf.gather 指定したaxisに沿ってスライスしたものをgatherする。

- tf.gather_nd 自由度が高く、任意の要素からgatherされる。

- tf.batch_gather batchに沿って処理される。indicesでスライスされる。

今回はバッチごとにgather処理をしてくれるtf.batch_gatherを使ってみた。


tf.batch_gatherの使い方


tf.batch_gatherの解説

tf.batch_gather(

params,
indices,
name=None
)

prams(Tensor): [A1, ..., AN-1, AN, B1, ..., BM]

indices(Tensor): [A1, ..., AN-1, C]

result(Tensor): [A1, ..., AN-1, C, B1, ..., BM]


  • ANでスライスされ、A1〜AN-1の構造はそのまま維持される。

  • CにANからgatherされる値が入る。

  • Cの各値はANの大きさ以下でなければいけない。

詳しくはTensorFlowの公式ドキュメントを参考にしてもらいたい。


今回の処理

点群屋さんなので、点群を例に説明します。

8点から成る3次元点群を用意した。ここから4点を選びgatherするような処理を行う。

Bはバッチ、Nは点群の点、Cはチャンネル(3次元点群なのでxyz)

gather.jpg


サンプルコード


batch_gather.py

import tensorflow as tf

#元の点群(2batch x 8点 x 3次元)
param = tf.constant([[[0.,0.,9.],[0.,1.,9.], [0.,2.,9.], [0.,3.,9.], [0.,4.,9.], [0.,5.,9.], [0.,6.,9.], [0.,7.,9.]],
[[1.,0.,9.],[1.,1.,9.], [1.,2.,9.], [1.,3.,9.], [1.,4.,9.], [1.,5.,9.], [1.,6.,9.], [1.,7.,9.]]]) # B x N x C(2 x 8 x 3)
print("param shape: ",param.shape)

#gather元を選ぶテンソル(2batch x 4点)
indices = tf.constant([[1,0,0,4],
[1,3,4,6]]) # B x N
print("indices shape: ",indices.shape)

#gatherする
result = tf.batch_gather(param, indices)
print("result shape: ",result.shape)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print("param\n",sess.run(param)) #入力の確認
print("indices\n",sess.run(indices)) #indiciesの確認
print("result\n",sess.run(result)) #gatherの結果


実行結果

param shape:      (2, 8, 3)

indices shape: (2, 4)
result shape: (2, 4, 3)

param
[[[0. 0. 9.]
[0. 1. 9.]
[0. 2. 9.]
[0. 3. 9.]
[0. 4. 9.]
[0. 5. 9.]
[0. 6. 9.]
[0. 7. 9.]]

[[1. 0. 9.]
[1. 1. 9.]
[1. 2. 9.]
[1. 3. 9.]
[1. 4. 9.]
[1. 5. 9.]
[1. 6. 9.]
[1. 7. 9.]]]

indices
[[1 0 0 4]
[1 3 4 6]]

result
[[[0. 1. 9.]
[0. 0. 9.]
[0. 0. 9.]
[0. 4. 9.]]

[[1. 1. 9.]
[1. 3. 9.]
[1. 4. 9.]
[1. 6. 9.]]]