tf.whereでのイコール条件の指定ではまった
見出し通り。なんかハマった。
自分が、イコール指定できてると思ったところで落ちたりした。
ということで、確認&解決したので、その備忘録
tf.whereでの基本的な条件指定 (確認)
まずは、
th.whereの基本的な条件
GT(Greater Than)
,LT(Less Than)
,GE(Greather than Equal)
,LE(Less than Equal)
,EQ(Equal)
,NE(Not Equal)
の指定方法の確認してみた。
where_conditions.py
import tensorflow as tf
def print_tensor(t, name=''):
print('## : ', name)
print(t)
print('')
t = tf.constant([0, 1, 1, 0, 1])
print_tensor(t, name='t')
gt_t = tf.where(t > 0)
print_tensor(gt_t, name='gt_t (0 < t)')
lt_t = tf.where(t < 1)
print_tensor(lt_t, name='lt_t (t < 1)')
ge_t = tf.where(t >= 1)
print_tensor(ge_t, name='ge_t (1 <= t)')
le_t = tf.where(t <= 0)
print_tensor(le_t, name='le_t (t <= 0)')
eq_t = tf.where(t == 0)
print_tensor(eq_t, name='eq_t (t == 0)')
ne_t = tf.where(t != 0)
print_tensor(ne_t, name='ne_t (t != 0)')
実行結果が、これ↓
Using TensorFlow backend.
## : t
Tensor("Const:0", shape=(5,), dtype=int32)
## : gt_t (0 < t)
Tensor("Where:0", shape=(?, 1), dtype=int64)
## : lt_t (t < 1)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)
## : ge_t (1 <= t)
Tensor("Where_2:0", shape=(?, 1), dtype=int64)
## : le_t (t <= 0)
Tensor("Where_3:0", shape=(?, 1), dtype=int64)
## : eq_t (t == 0)
Tensor("Where_4:0", shape=(?, 0), dtype=int64)
## : ne_t (t != 0)
Tensor("Where_5:0", shape=(?, 0), dtype=int64)
ん?
Equal(==)
, NotEqual(!=)
の次元が、0
でおかしい。
以下のように変更して、実行すると
where_conditions.py
import tensorflow as tf
import keras.backend as KB
def print_tensor(t, name=''):
print('## : ', name)
print(t)
v = KB.get_value(t)
print(v)
print('')
~~ 省略 ~~
もちろん、
## : t
Tensor("Const:0", shape=(5,), dtype=int32)
[0 1 1 0 1]
## : gt_t (0 < t)
Tensor("Where:0", shape=(?, 1), dtype=int64)
[[1]
[2]
[4]]
## : lt_t (t < 1)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)
[[0]
[3]]
## : ge_t (1 <= t)
Tensor("Where_2:0", shape=(?, 1), dtype=int64)
[[1]
[2]
[4]]
## : le_t (t <= 0)
Tensor("Where_3:0", shape=(?, 1), dtype=int64)
[[0]
[3]]
## : eq_t (t == 0)
Tensor("Where_4:0", shape=(?, 0), dtype=int64)
Traceback (most recent call last):
~~ 省略 ~~
のように
eq_t
のKB.get_value
時に落ちる。
(eq_t
がなければ、 ne_t
でおちる。)
落ちてたのこの辺のせいね。
tf.whereでのイコール条件指定 (正解)
where_conditions_fix.py
import tensorflow as tf
import keras.backend as KB
def print_tensor(t, name=''):
print('## : ', name)
print(t)
v = KB.get_value(t)
print(v)
print('')
t = tf.constant([0, 1, 1, 0, 1])
print_tensor(t, name='t')
gt_t = tf.where(t > 0)
print_tensor(gt_t, name='gt_t (0 < t)')
lt_t = tf.where(t < 1)
print_tensor(lt_t, name='lt_t (t < 1)')
ge_t = tf.where(t >= 1)
print_tensor(ge_t, name='ge_t (1 <= 1)')
le_t = tf.where(t <= 0)
print_tensor(le_t, name='le_t (t <= 0)')
# eq_t = tf.where(t == 0) <- Not expected behavior !
eq_t = tf.where(tf.equal(t, 0))
print_tensor(eq_t, name='eq_t (t == 0)')
# ne_t = tf.where(t != 0) <- Not expected behavior !
ne_t = tf.where(tf.not_equal(t, 0))
print_tensor(ne_t, name='ne_t (t != 0)')
出力は、以下
## : t
Tensor("Const:0", shape=(5,), dtype=int32)
[0 1 1 0 1]
## : gt_t (0 < t)
Tensor("Where:0", shape=(?, 1), dtype=int64)
[[1]
[2]
[4]]
## : lt_t (t < 1)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)
[[0]
[3]]
## : ge_t (1 <= 1)
Tensor("Where_2:0", shape=(?, 1), dtype=int64)
[[1]
[2]
[4]]
## : le_t (t <= 0)
Tensor("Where_3:0", shape=(?, 1), dtype=int64)
[[0]
[3]]
## : eq_t (t == 0)
Tensor("Where_4:0", shape=(?, 1), dtype=int64)
[[0]
[3]]
## : ne_t (t != 0)
Tensor("Where_5:0", shape=(?, 1), dtype=int64)
[[1]
[2]
[4]]
これで、無事、期待通りの動作になった。
にしても、<=
, >=
は使えるのに、==
, !=
は、アウトなんてあるんですね。(わかりづらい。。。)
tf.whereでの基本的な条件指定 (tfの関数使用)
おまけだけど、もちろん、GT
, LT
, GE
, LE
も、tfの関数を使っての指定もできる。
where_conditions_tffunc.py
~~ 省略 ~~
gt_t = tf.where(tf.greater(t, 0))
print_tensor(gt_t, name='gt_t (0 < t)')
lt_t = tf.where(tf.less(t, 1))
print_tensor(lt_t, name='lt_t (t < 1)')
ge_t = tf.where(tf.greater_equal(t, 1))
print_tensor(ge_t, name='ge_t (1 <= 1)')
le_t = tf.where(tf.less_equal(t, 0))
print_tensor(le_t, name='le_t (t <= 0)')
eq_t = tf.where(tf.equal(t, 0))
print_tensor(eq_t, name='eq_t (t == 0)')
ne_t = tf.where(tf.not_equal(t, 0))
print_tensor(ne_t, name='ne_t (t != 0)')
## : t
[0 1 1 0 1]
## : gt_t (0 < t)
Tensor("Where:0", shape=(?, 1), dtype=int64)
[[1]
[2]
[4]]
## : lt_t (t < 1)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)
[[0]
[3]]
## : ge_t (1 <= 1)
Tensor("Where_2:0", shape=(?, 1), dtype=int64)
[[1]
[2]
[4]]
## : le_t (t <= 0)
Tensor("Where_3:0", shape=(?, 1), dtype=int64)
[[0]
[3]]
## : eq_t (t == 0)
Tensor("Where_4:0", shape=(?, 1), dtype=int64)
[[0]
[3]]
## : ne_t (t != 0)
Tensor("Where_5:0", shape=(?, 1), dtype=int64)
[[1]
[2]
[4]]
tf.whereでの複合条件指定
さらにおまけ。複合条件は、こんな感じ。
where_composite_conditions.py
import tensorflow as tf
import keras.backend as KB
def print_tensor(t, name=''):
print('## : ', name)
print(t)
v = KB.get_value(t)
print(v)
print('')
def pt(t, name=''):
print_tensor(t, name=name)
t = tf.constant([3, 1, 2, 0, 4])
pt(t, name='t')
and_t = tf.where((t < 4) & (t > 1))
print_tensor(and_t, name='and_t (1 < t < 4)')
or_t = tf.where((t < 2) | (t > 3))
出力は、以下
## : t
Tensor("Const:0", shape=(5,), dtype=int32)
[3 1 2 0 4]
## : and_t (1 < t < 4)
Tensor("Where:0", shape=(?, 1), dtype=int64)
[[0]
[2]]
## : or_t (t < 2, 3 < t)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)
[[1]
[3]
[4]]