LoginSignup
7
3

More than 5 years have passed since last update.

tf.whereでのイコール条件の指定方法

Posted at

tf.whereでのイコール条件の指定ではまった

見出し通り。なんかハマった。
自分が、イコール指定できてると思ったところで落ちたりした。
ということで、確認&解決したので、その備忘録

tf.whereでの基本的な条件指定 (確認)

まずは、
th.whereの基本的な条件
GT(Greater Than),LT(Less Than),GE(Greather than Equal),LE(Less than Equal),EQ(Equal),NE(Not Equal)
の指定方法の確認してみた。

where_conditions.py
import tensorflow as tf

def print_tensor(t, name=''):
    print('## : ', name)
    print(t)
    print('')

t = tf.constant([0, 1, 1, 0, 1])
print_tensor(t, name='t')

gt_t = tf.where(t > 0)
print_tensor(gt_t, name='gt_t (0 < t)')

lt_t = tf.where(t < 1)
print_tensor(lt_t, name='lt_t (t < 1)')

ge_t = tf.where(t >= 1)
print_tensor(ge_t, name='ge_t (1 <= t)')

le_t = tf.where(t <= 0)
print_tensor(le_t, name='le_t (t <= 0)')

eq_t = tf.where(t == 0)
print_tensor(eq_t, name='eq_t (t == 0)')

ne_t = tf.where(t != 0)
print_tensor(ne_t, name='ne_t (t != 0)')

実行結果が、これ↓

Using TensorFlow backend.
## :  t
Tensor("Const:0", shape=(5,), dtype=int32)

## :  gt_t (0 < t)
Tensor("Where:0", shape=(?, 1), dtype=int64)

## :  lt_t (t < 1)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)

## :  ge_t (1 <= t)
Tensor("Where_2:0", shape=(?, 1), dtype=int64)

## :  le_t (t <= 0)
Tensor("Where_3:0", shape=(?, 1), dtype=int64)

## :  eq_t (t == 0)
Tensor("Where_4:0", shape=(?, 0), dtype=int64)

## :  ne_t (t != 0)
Tensor("Where_5:0", shape=(?, 0), dtype=int64)

ん?
Equal(==), NotEqual(!=)の次元が、0でおかしい。

以下のように変更して、実行すると

where_conditions.py
import tensorflow as tf
import keras.backend as KB

def print_tensor(t, name=''):
    print('## : ', name)
    print(t)
    v = KB.get_value(t)
    print(v)
    print('')

~~ 省略 ~~

もちろん、

## :  t
Tensor("Const:0", shape=(5,), dtype=int32)
[0 1 1 0 1]

## :  gt_t (0 < t)
Tensor("Where:0", shape=(?, 1), dtype=int64)
[[1]
 [2]
 [4]]

## :  lt_t (t < 1)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)
[[0]
 [3]]

## :  ge_t (1 <= t)
Tensor("Where_2:0", shape=(?, 1), dtype=int64)
[[1]
 [2]
 [4]]

## :  le_t (t <= 0)
Tensor("Where_3:0", shape=(?, 1), dtype=int64)
[[0]
 [3]]

## :  eq_t (t == 0)
Tensor("Where_4:0", shape=(?, 0), dtype=int64)
Traceback (most recent call last):
~~ 省略 ~~

のように
eq_tKB.get_value時に落ちる。
(eq_tがなければ、 ne_tでおちる。)
落ちてたのこの辺のせいね。

tf.whereでのイコール条件指定 (正解)

where_conditions_fix.py
import tensorflow as tf
import keras.backend as KB

def print_tensor(t, name=''):
    print('## : ', name)
    print(t)
    v = KB.get_value(t)
    print(v)
    print('')

t = tf.constant([0, 1, 1, 0, 1])
print_tensor(t, name='t')

gt_t = tf.where(t > 0)
print_tensor(gt_t, name='gt_t (0 < t)')

lt_t = tf.where(t < 1)
print_tensor(lt_t, name='lt_t (t < 1)')

ge_t = tf.where(t >= 1)
print_tensor(ge_t, name='ge_t (1 <= 1)')

le_t = tf.where(t <= 0)
print_tensor(le_t, name='le_t (t <= 0)')

#eq_t = tf.where(t == 0) <- Not expected behavior !
eq_t = tf.where(tf.equal(t, 0))
print_tensor(eq_t, name='eq_t (t == 0)')

#ne_t = tf.where(t != 0) <- Not expected behavior !
ne_t = tf.where(tf.not_equal(t, 0))
print_tensor(ne_t, name='ne_t (t != 0)')

出力は、以下

## :  t
Tensor("Const:0", shape=(5,), dtype=int32)
[0 1 1 0 1]

## :  gt_t (0 < t)
Tensor("Where:0", shape=(?, 1), dtype=int64)
[[1]
 [2]
 [4]]

## :  lt_t (t < 1)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)
[[0]
 [3]]

## :  ge_t (1 <= 1)
Tensor("Where_2:0", shape=(?, 1), dtype=int64)
[[1]
 [2]
 [4]]

## :  le_t (t <= 0)
Tensor("Where_3:0", shape=(?, 1), dtype=int64)
[[0]
 [3]]

## :  eq_t (t == 0)
Tensor("Where_4:0", shape=(?, 1), dtype=int64)
[[0]
 [3]]

## :  ne_t (t != 0)
Tensor("Where_5:0", shape=(?, 1), dtype=int64)
[[1]
 [2]
 [4]]

これで、無事、期待通りの動作になった。
にしても、<=, >=は使えるのに、==, !=は、アウトなんてあるんですね。(わかりづらい。。。)

tf.whereでの基本的な条件指定 (tfの関数使用)

おまけだけど、もちろん、GT, LT, GE, LEも、tfの関数を使っての指定もできる。

where_conditions_tffunc.py
~~ 省略 ~~

gt_t = tf.where(tf.greater(t, 0))
print_tensor(gt_t, name='gt_t (0 < t)')

lt_t = tf.where(tf.less(t, 1))
print_tensor(lt_t, name='lt_t (t < 1)')

ge_t = tf.where(tf.greater_equal(t, 1))
print_tensor(ge_t, name='ge_t (1 <= 1)')

le_t = tf.where(tf.less_equal(t, 0))
print_tensor(le_t, name='le_t (t <= 0)')

eq_t = tf.where(tf.equal(t, 0))
print_tensor(eq_t, name='eq_t (t == 0)')

ne_t = tf.where(tf.not_equal(t, 0))
print_tensor(ne_t, name='ne_t (t != 0)')
## :  t
[0 1 1 0 1]

## :  gt_t (0 < t)
Tensor("Where:0", shape=(?, 1), dtype=int64)
[[1]
 [2]
 [4]]

## :  lt_t (t < 1)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)
[[0]
 [3]]

## :  ge_t (1 <= 1)
Tensor("Where_2:0", shape=(?, 1), dtype=int64)
[[1]
 [2]
 [4]]

## :  le_t (t <= 0)
Tensor("Where_3:0", shape=(?, 1), dtype=int64)
[[0]
 [3]]

## :  eq_t (t == 0)
Tensor("Where_4:0", shape=(?, 1), dtype=int64)
[[0]
 [3]]

## :  ne_t (t != 0)
Tensor("Where_5:0", shape=(?, 1), dtype=int64)
[[1]
 [2]
 [4]]

tf.whereでの複合条件指定

さらにおまけ。複合条件は、こんな感じ。

where_composite_conditions.py
import tensorflow as tf
import keras.backend as KB

def print_tensor(t, name=''):
    print('## : ', name)
    print(t)
    v = KB.get_value(t)
    print(v)
    print('')

def pt(t, name=''):
    print_tensor(t, name=name)

t = tf.constant([3, 1, 2, 0, 4])
pt(t, name='t')

and_t = tf.where((t < 4) & (t > 1))
print_tensor(and_t, name='and_t (1 < t < 4)')

or_t = tf.where((t < 2) | (t > 3))

出力は、以下

## :  t
Tensor("Const:0", shape=(5,), dtype=int32)
[3 1 2 0 4]

## :  and_t (1 < t < 4)
Tensor("Where:0", shape=(?, 1), dtype=int64)
[[0]
 [2]]

## :  or_t (t < 2, 3 < t)
Tensor("Where_1:0", shape=(?, 1), dtype=int64)
[[1]
 [3]
 [4]]
7
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
3