はじめに(やりたいこと)
最近はpytorchから離れ、tensorflowとお友達になるべく日々奮闘しています。
もちろんpytorchのことを嫌いになったわけではないですよ。
行列演算をしているときにtf.broadcast_to
でtensorをいい感じにしたいことがあります。でも、もしそのtensorがバッチだった場合……
>>> x = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
>>> tf.broadcast_to(x, [2, 5, 5])
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/***/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 859, in broadcast_to
input, shape, name=name, ctx=_ctx)
File "/home/***/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 905, in broadcast_to_eager_fallback
attrs=_attrs, ctx=ctx, name=name)
File "/home/***/python3.7/site-packages/tensorflow/python/eager/execute.py", line 55, in quick_execute
inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [2,5] vs. [2,5,5] [Op:BroadcastTo]
と出てブロードキャストできません。それもそのはず、
Two shapes are compatible if for each dimension pair they are either equal or one of them is one.
(2つのshapeは各次元のペアについて、それらが等しいか、どちらかが1である場合に互換性があります。)
ということで、tf.broadcast_to
を使わずに同じ結果を得ようと思います。
こんな感じのtensorが得られればゴールです。[batch, dim]
を[batch, dim, dim]
にしたい、ということです。
<tf.Tensor: shape=(2, 5, 5), dtype=int32, numpy=
array([[[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]],
[[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0]]], dtype=int32)>
解決策
適当に実装したのが方法1なのですが、同僚から方法2, 3についても教えて頂きました。
方法1:repeat
とreshape
を組み合わせる
各batch
ごとにdim
回リピートして、reshape
で欲しい形に変更すれば良いです。
x = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]]) # [batch, dim]
y2 = tf.repeat(x, repeats=x.shape[1], axis=0) # [batch*dim, dim]
y2 = tf.reshape(y2, [2, x.shape[1], x.shape[1]]) # [batch, dim, dim]
>>> y2
<tf.Tensor: shape=(2, 5, 5), dtype=int32, numpy=
array([[[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]],
[[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0]]], dtype=int32)>
方法2:expand_dims
で次元を拡大してrepeat
batch
のあとに1次元増やして、増えた次元を指定してdim
回分repeat
します。
方法1の手順を逆にした感じですね。
x = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]]) # [batch, dim]
y2 = tf.expand_dims(x, axis=1) # [batch, 1, dim]
y2 = tf.repeat(y2, repeats=[x.shape[1]], axis=1) # [batch, dim, dim]
>>> y2
<tf.Tensor: shape=(2, 5, 5), dtype=int32, numpy=
array([[[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]],
[[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0]]], dtype=int32)>
方法3:map_fn
を使う
map_fn
を用いてバッチごとにbroadcast_to
を適用します。
tf.map_fn(func, elems=x)
でx
の各バッチをfunc
に入力します。
>>> tf.map_fn(lambda x: tf.broadcast_to(x, (x.shape[0], x.shape[0])), elems=x)
<tf.Tensor: shape=(2, 5, 5), dtype=int32, numpy=
array([[[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]],
[[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0],
[6, 7, 8, 9, 0]]], dtype=int32)>
おわりに
思想的にあくまでbroadcastをしたいのであれば方法3を、別に気にしないのであればどれでも良いのかなと思います。
そしてオチなのですが、結局この処理は使わないことになっちゃいました。故に計算速度については調べていません……。