0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

tf.broadcast_toを各バッチごとに行いたい

Last updated at Posted at 2023-02-02

はじめに(やりたいこと)

最近はpytorchから離れ、tensorflowとお友達になるべく日々奮闘しています。
もちろんpytorchのことを嫌いになったわけではないですよ。

行列演算をしているときにtf.broadcast_toでtensorをいい感じにしたいことがあります。でも、もしそのtensorがバッチだった場合……

>>> x = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
>>> tf.broadcast_to(x, [2, 5, 5])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/***/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 859, in broadcast_to
    input, shape, name=name, ctx=_ctx)
  File "/home/***/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 905, in broadcast_to_eager_fallback
    attrs=_attrs, ctx=ctx, name=name)
  File "/home/***/python3.7/site-packages/tensorflow/python/eager/execute.py", line 55, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [2,5] vs. [2,5,5] [Op:BroadcastTo]

と出てブロードキャストできません。それもそのはず、

Two shapes are compatible if for each dimension pair they are either equal or one of them is one.
(2つのshapeは各次元のペアについて、それらが等しいか、どちらかが1である場合に互換性があります。)

ということで、tf.broadcast_toを使わずに同じ結果を得ようと思います。
こんな感じのtensorが得られればゴールです。[batch, dim][batch, dim, dim]にしたい、ということです。

<tf.Tensor: shape=(2, 5, 5), dtype=int32, numpy=
array([[[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5]],

       [[6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0]]], dtype=int32)>

解決策

適当に実装したのが方法1なのですが、同僚から方法2, 3についても教えて頂きました。

方法1:repeatreshapeを組み合わせる

batchごとにdim回リピートして、reshapeで欲しい形に変更すれば良いです。

x = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]]) # [batch, dim]
y2 = tf.repeat(x, repeats=x.shape[1], axis=0) # [batch*dim, dim]
y2 = tf.reshape(y2, [2, x.shape[1], x.shape[1]]) # [batch, dim, dim]

>>> y2
<tf.Tensor: shape=(2, 5, 5), dtype=int32, numpy=
array([[[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5]],

       [[6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0]]], dtype=int32)>

方法2:expand_dimsで次元を拡大してrepeat

batchのあとに1次元増やして、増えた次元を指定してdim回分repeatします。
方法1の手順を逆にした感じですね。

x = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]]) # [batch, dim]
y2 = tf.expand_dims(x, axis=1) # [batch, 1, dim]
y2 = tf.repeat(y2, repeats=[x.shape[1]], axis=1) # [batch, dim, dim]

>>> y2
<tf.Tensor: shape=(2, 5, 5), dtype=int32, numpy=
array([[[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5]],

       [[6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0]]], dtype=int32)>

方法3:map_fnを使う

map_fnを用いてバッチごとにbroadcast_toを適用します。
tf.map_fn(func, elems=x)xの各バッチをfuncに入力します。

>>> tf.map_fn(lambda x: tf.broadcast_to(x, (x.shape[0], x.shape[0])), elems=x)
<tf.Tensor: shape=(2, 5, 5), dtype=int32, numpy=
array([[[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5]],

       [[6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0],
        [6, 7, 8, 9, 0]]], dtype=int32)>

おわりに

思想的にあくまでbroadcastをしたいのであれば方法3を、別に気にしないのであればどれでも良いのかなと思います。

そしてオチなのですが、結局この処理は使わないことになっちゃいました。故に計算速度については調べていません……。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?