4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

tf.batch_gatherの使い方メモ

Posted at

#tf.gatherと仲間たち

まず、gatheの処理をしてくれる仲間たちを紹介する。

  • tf.gather 指定したaxisに沿ってスライスしたものをgatherする。
  • tf.gather_nd 自由度が高く、任意の要素からgatherされる。
  • tf.batch_gather batchに沿って処理される。indicesでスライスされる。
    今回はバッチごとにgather処理をしてくれるtf.batch_gatherを使ってみた。

#tf.batch_gatherの使い方

##tf.batch_gatherの解説

tf.batch_gather(
    params,
    indices,
    name=None
)

prams(Tensor): [A1, ..., AN-1, AN, B1, ..., BM]
indices(Tensor): [A1, ..., AN-1, C]
result(Tensor): [A1, ..., AN-1, C, B1, ..., BM]

  • ANでスライスされ、A1〜AN-1の構造はそのまま維持される。
  • CにANからgatherされる値が入る。
  • Cの各値はANの大きさ以下でなければいけない。

詳しくはTensorFlowの公式ドキュメントを参考にしてもらいたい。

##今回の処理
点群屋さんなので、点群を例に説明します。
8点から成る3次元点群を用意した。ここから4点を選びgatherするような処理を行う。
Bはバッチ、Nは点群の点、Cはチャンネル(3次元点群なのでxyz)

gather.jpg

##サンプルコード

batch_gather.py
import tensorflow as tf

#元の点群(2batch x 8点 x 3次元)
param = tf.constant([[[0.,0.,9.],[0.,1.,9.], [0.,2.,9.], [0.,3.,9.], [0.,4.,9.], [0.,5.,9.], [0.,6.,9.], [0.,7.,9.]],
                     [[1.,0.,9.],[1.,1.,9.], [1.,2.,9.], [1.,3.,9.], [1.,4.,9.], [1.,5.,9.], [1.,6.,9.], [1.,7.,9.]]]) # B x N x C(2 x 8 x 3)
print("param shape:     ",param.shape)

#gather元を選ぶテンソル(2batch x 4点)
indices = tf.constant([[1,0,0,4],
                       [1,3,4,6]])    # B x N
print("indices shape:   ",indices.shape)

#gatherする
result = tf.batch_gather(param, indices)
print("result shape:    ",result.shape)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print("param\n",sess.run(param))     #入力の確認
print("indices\n",sess.run(indices)) #indiciesの確認
print("result\n",sess.run(result))   #gatherの結果

実行結果

param shape:      (2, 8, 3)
indices shape:    (2, 4)
result shape:     (2, 4, 3)

param
 [[[0. 0. 9.]
  [0. 1. 9.]
  [0. 2. 9.]
  [0. 3. 9.]
  [0. 4. 9.]
  [0. 5. 9.]
  [0. 6. 9.]
  [0. 7. 9.]]

 [[1. 0. 9.]
  [1. 1. 9.]
  [1. 2. 9.]
  [1. 3. 9.]
  [1. 4. 9.]
  [1. 5. 9.]
  [1. 6. 9.]
  [1. 7. 9.]]]

indices
 [[1 0 0 4]
 [1 3 4 6]]

result
 [[[0. 1. 9.]
  [0. 0. 9.]
  [0. 0. 9.]
  [0. 4. 9.]]

 [[1. 1. 9.]
  [1. 3. 9.]
  [1. 4. 9.]
  [1. 6. 9.]]]
4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?