0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

tensorflowのRandomcropでエラーが出る

Last updated at Posted at 2024-04-15

dataaugmentaionとは

一般にデータ増強などと訳される処理のことであり、機械学習モデル(特にNN)において学習を行う場合に過学習しないように行う処理である。
過学習とは「モデルが学習データのみへ過剰に適合してしまう」ことである。dataaugmentaionは訓練データにランダムノイズを加えるなどして、最も学習させたい特徴量のみを学習させるよう(正規化)に仕向ける。

kaggleのlearnにあるcomputer visionコースの6コマ目が詳しく解説している。

またtensorflowにおけるdataaugmentaionの実装はtensorflowチュートリアルが詳しい

今回出たエラー

次のようなコードを実行した

def parse_tfrecord_fn(example):
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "kyoushi1": tf.io.FixedLenFeature([], tf.int64),
        "kyoushi2": tf.io.FixedLenFeature([], tf.int64),
        "ly": tf.io.FixedLenFeature([], tf.int64),
        "lx": tf.io.FixedLenFeature([], tf.int64),
        "lw": tf.io.FixedLenFeature([], tf.int64),
        "lh": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, feature_description)
    example["image"] = tf.io.decode_jpeg(example["image"], channels=3)
    return example

def augmentation(image, training = True):
    if training:
        aug = tf.keras.Sequential([
                layers.Resizing(20, 20),
                layers.RandomCrop(12, 12),
                layers.Rescaling(1./255),])
    else:
        aug = tf.keras.Sequential([
                layers.Resizing(12, 12),
                layers.Rescaling(1./255),])
    image = aug(image)
    return image

def prepare_fit_sample(features):
    img_feat = features['image']
    l_im = tf.image.crop_to_bounding_box(img_feat, ly, lx, lh, lw)
    l_im = augmentation(l_im, False)

    y = [features['kyoushi1'], features['kyoushi2']]

    return l_im, y

if __name__ == '__main__':
    filenames = hogehoge.tfrec
    dataset = (
            tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
            .map(parse_tfrecord, num_parallel_calls=AUTO)
            .map(prepare_sample, num_parallel_calls=AUTO)
            )

実行時のエラーメッセージは以下のとおりである

tf.function only supports singleton tf.variables created on the first call. make sure the tf.variable is only created once or created outside tf.function. see https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

原因究明

randomcropをコメントアウトすると正常に実行できたことからこいつが悪さをしていると判明。
エラーメッセージからtfdata.map内でrandomcropレイヤーを呼び出す(インスタンス化)することでtf.randomgenerator内のtf.variableが干渉してしまっているようであると考えた。

解決法

Sequantial形式でのlayer構築をmapの外側で行った

def parse_tfrecord_fn(example):
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "kyoushi1": tf.io.FixedLenFeature([], tf.int64),
        "kyoushi2": tf.io.FixedLenFeature([], tf.int64),
        "ly": tf.io.FixedLenFeature([], tf.int64),
        "lx": tf.io.FixedLenFeature([], tf.int64),
        "lw": tf.io.FixedLenFeature([], tf.int64),
        "lh": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, feature_description)
    example["image"] = tf.io.decode_jpeg(example["image"], channels=3)
    return example

######################################################deleted augmentation

def prepare_fit_samplemod(features):
    img_feat = features['image']
    l_im = tf.image.crop_to_bounding_box(img_feat, ly, lx, lh, lw)
    #l_im = augmentation(l_im, False)####deleted augmentation

    y = [features['kyoushi1'], features['kyoushi2']]

    return l_im, y

if __name__ == '__main__':
    filenames = hogehoge.tfrec
    #######################################################inserted sequantial
    aug = tf.keras.Sequential([
        layers.Resizing(20, 20),
        layers.RandomCrop(12, 12),
        layers.Rescaling(1./255),])
    #######################################################
    dataset = (
            tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
            .map(parse_tfrecord, num_parallel_calls=AUTO)
            .map(prepare_samplemod, num_parallel_calls=AUTO)
            .map(lambda x , y: (aug(x), y), num_parallel_calls=AUTO)################call layer
            )

参考

tensorflowのマニュアルを参照した

またこちらのstackoverflowの記事も参考になる

どちらの記事も本エラーとはあまり関係ないかも?

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?