はじめに
※この記事の内容はTensorflow 2.5以上が対象です(2.4だとnp_config
やenable_numpy_behavior
が存在しないためです)。
Tensorflowで演算処理をするとき、通常では以下のような書き方はできません:
import tensorflow as tf
x = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
# こういう計算をしたい
mm = x.max(axis=0).mean()
AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'max'
だったり、
# np.chooseやnp.takeの挙動をさせたい(fancy indexingをさせたい)
collect = x[:,[2,0,1]]
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got [2, 0, 1]
つまり、numpyで使える書き方ができません。jaxやpytorchではできますが、tensorflowでは対応していないのです。tensorflowが時々使いにくいとかコメントされているのは、こういう所もあるかもしれませんね。
ただ、一応tensorflowでもnumpyのような挙動をさせる手法はあります
numpyの挙動をさせる
簡単です。2行で終わります。
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
# or
from tensorflow.experimental.numpy import enable_numpy_behavior
enable_numpy_behavior()
# ok
mm = x.max(axis=0).mean()
# これもok(中身はtf.gatherがやってる)
x[:,[2,0,1]]
できるようになること
- fancy indexingが使えるようになる
- 次のmethodが使えるようになる
-
T
,transpose
ndim
size
tolist
reshape
ravel
clip
astype
max
mean
min
data
-
できないこと
上記以外(例えば、tf.Tensor.sum
など)は使えない。
参照: tensorflow/python/ops/numpy_ops.py
注意点
np_config.enable_numpy_behavior()
はtensorflowを遅くします。有効にするだけでもです。
おそらくfancy indexingの設定(ops.enable_numpy_style_slicing
)かなとは思ってますが、学習で使ったりすると結構時間に差が出てきます。
別にfancy indexingを使わないのであれば、次の関数にしてtf.Tensor
にmethodだけ追加しておけばokです。なんだったら自分でsetattr(ops.Tensor, "関数名")
を入れて無理やり追加しちゃってもいい
from tensorflow.python.ops.numpy_ops import enable_numpy_methods_on_tensor
enable_numpy_methods_on_tensor()
どうしてもfancy indexingを使いたい場合は、おとなしくtf.gather
かtf.gather_nd
を使うしかなさそうです。
おまけ
numpyの関数の名前でtensorflowを使うときは、tf.experimental.numpy
をimportすることでできます。内部的にはtensorflowの関数の名前を変えたようなものです。
https://www.tensorflow.org/api_docs/python/tf/experimental/numpy
import tf.experimental.numpy as tnp # これをimport
from tensorflow.python.ops.numpy_ops import enable_numpy_methods_on_tensor
enable_numpy_methods_on_tensor()
x = tnp.array([[1, 2, 3], [4, 5, 6]], dtype=tnp.float32)
print(x.max(axis=0).mean())
そろそろexperimental
は外しても良いような…