1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Tensorflowでnumpyの挙動をさせる

Last updated at Posted at 2021-11-06

はじめに

※この記事の内容はTensorflow 2.5以上が対象です(2.4だとnp_configenable_numpy_behaviorが存在しないためです)。


Tensorflowで演算処理をするとき、通常では以下のような書き方はできません:

import tensorflow as tf

x = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)

# こういう計算をしたい
mm = x.max(axis=0).mean()
AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'max'

だったり、

# np.chooseやnp.takeの挙動をさせたい(fancy indexingをさせたい)
collect = x[:,[2,0,1]]
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got [2, 0, 1]

つまり、numpyで使える書き方ができません。jaxやpytorchではできますが、tensorflowでは対応していないのです。tensorflowが時々使いにくいとかコメントされているのは、こういう所もあるかもしれませんね。

ただ、一応tensorflowでもnumpyのような挙動をさせる手法はあります

numpyの挙動をさせる

簡単です。2行で終わります。

from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

# or
from tensorflow.experimental.numpy import enable_numpy_behavior
enable_numpy_behavior()

# ok
mm = x.max(axis=0).mean()

# これもok(中身はtf.gatherがやってる)
x[:,[2,0,1]]

できるようになること

  • fancy indexingが使えるようになる
  • 次のmethodが使えるようになる
    • T, transpose
    • ndim
    • size
    • tolist
    • reshape
    • ravel
    • clip
    • astype
    • max
    • mean
    • min
    • data

できないこと

上記以外(例えば、tf.Tensor.sumなど)は使えない。
参照: tensorflow/python/ops/numpy_ops.py

注意点

np_config.enable_numpy_behavior()tensorflowを遅くします。有効にするだけでもです。
おそらくfancy indexingの設定(ops.enable_numpy_style_slicing)かなとは思ってますが、学習で使ったりすると結構時間に差が出てきます。
別にfancy indexingを使わないのであれば、次の関数にしてtf.Tensorにmethodだけ追加しておけばokです。なんだったら自分でsetattr(ops.Tensor, "関数名")を入れて無理やり追加しちゃってもいい

from tensorflow.python.ops.numpy_ops import enable_numpy_methods_on_tensor
enable_numpy_methods_on_tensor()

どうしてもfancy indexingを使いたい場合は、おとなしくtf.gathertf.gather_ndを使うしかなさそうです。

おまけ

numpyの関数の名前でtensorflowを使うときは、tf.experimental.numpyをimportすることでできます。内部的にはtensorflowの関数の名前を変えたようなものです。
https://www.tensorflow.org/api_docs/python/tf/experimental/numpy

import tf.experimental.numpy as tnp # これをimport
from tensorflow.python.ops.numpy_ops import enable_numpy_methods_on_tensor
enable_numpy_methods_on_tensor()

x = tnp.array([[1, 2, 3], [4, 5, 6]], dtype=tnp.float32)
print(x.max(axis=0).mean())

そろそろexperimentalは外しても良いような…

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?