tensorflowで行列を掛けるためにreshapeすることがあると思います。自分で書いていて、こんなことしてもいいのかな?と思ったことがあって確認してみたので、メモを残しておくことにします。
したいこと
import tensorflow as tf
x = tf.placeholder(tf.float32, [2, None, 2])
y = tf.reshape([-1, 2])
z = tf.reshape([2, -1, 2])
のように、Noneがshapeの途中に出てくる場合に、別のshapeにしてから元に戻せるかということが気になりました。(実際には、途中で、yに行列を掛けて、その結果を本来のshapeに戻したい)
調べた
ipythonで、tf.reshape??のように、クエスチョンマークを2つつけると、関数の定義が見れます。もっとも、コメントと他の関数を呼び出しくらいしか出てきませんが、コメントに自分の興味にあったことが書いてありました。
意訳すると、shapeに整数の列を入れる場合、"-1"を入れるとしたら、それは一つしか入れてはいけない。一つだけなら、他の次元から逆算できるから大丈夫だという話です。逆に、この"-1"はどこに入れなければいけないという制限が書いてないので、どこでもいいということでしょう。
試してみた
import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, [2, None, 2])
y = tf.reshape(x, [-1, 2])
z = tf.reshape(y, [2, -1, 2])
a = np.array([[[1,2],[3,4]],[[5,6],[7,8]]], dtype=np.float32)
with tf.Session() as sess:
print(a)
print(sess.run(z, {x: a}))
無事、aとzが同じものが表示されますね。めでたしめでたし。
本当にしたかったこと(数学っぽく)
f: \mathbb{R}^{n_k} \rightarrow \mathbb{R}^m
という線形写像で、
1\otimes\dots \otimes 1 \otimes f: \mathbb{R}^{n_1} \otimes \dots \otimes \mathbb{R}^{n_{k-1}} \otimes \mathbb{R}^{n_k} \longrightarrow \mathbb{R}^{n_1} \otimes \dots \otimes \mathbb{R}^{n_{k-1}} \otimes \mathbb{R}^{m}
のように最後だけ適用させたかった。これをするのにreshapeする必要があるのは、Tensorの名前を冠しているのに・・・と思ってしまう・・・