0
0

More than 1 year has passed since last update.

Tensor Flow エラー InvalidArgumentError: required broadcastable shapes 解決 2021/11/17

Posted at

本日のエラー「 InvalidArgumentError: required broadcastable shapes」
Error.png

本日に UNETのモデルを実行するとき、モデルのクラスを1から11まで増して、このエラーが引き起こされてしまう。原因は、Tensroflow
このエラーの原因は、私のデータを読み込む関数で、to_categoricalがなくて、tf_parseの入力サイズも違いました。

下記はエラーを引き起こされたコード

def read_mask(path):
    path = path.decode()
    x = cv2.imread(path,cv2.IMREAD_GRAYSCALE)
    x = np.expand_dims(x, axis=-1)
    x = x.astype(np.double)

    return x
def tf_parse(x, y):
    def _parse(x, y):
        x = read_image(x)
        y = read_mask(y)
        return x, y

    x, y = tf.numpy_function(_parse, [x, y], [tf.float64, tf.float64])
    x.set_shape([256, 256, 3])
    y.set_shape([256, 256,1])

    return x, y

上記のコードは、複数のクラスの機能がないですね。マスクの画像を読み込むと、GRAYSCALEの画像の一のクラスを仮定されています。そして、tf_parseでは、マスクのyの変数のサイズは、カラスの次元は1に設定されています。複数のクラスを使用したい場合は、この点を更新しなければなりません。

改善したコード

from tensorflow.keras.utils import to_categorical
def read_mask(path,N_CLASSES=N_CLASSES):
    path = path.decode()
    x = cv2.imread(path,cv2.IMREAD_GRAYSCALE)
    x = np.expand_dims(x, axis=-1)
    x= to_categorical(x, num_classes=N_CLASSES) #カラス化
    x = x.astype(np.double)

    return x


def tf_parse(x, y,N_CLASSES=N_CLASSES):
    def _parse(x, y):
        x = read_image(x)
        y = read_mask(y)
        return x, y

    x, y = tf.numpy_function(_parse, [x, y], [tf.float64, tf.float64])
    x.set_shape([256, 256, 3])
    y.set_shape([256, 256,N_CLASSES]) #正しいカラス数により次元

    return x, y

一つの注意点は、「x = x.astype(np.double)」の部分は、 「to_categorical」の後で置いていなければなりません。逆の場合は、他のエラーを引き起こされます。
Error2.png

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0