LoginSignup
6
0

More than 1 year has passed since last update.

tf.data、TFRecordsを使った画像読み込み

Last updated at Posted at 2021-08-27

チュートリアルとかにある全てnumpy配列に格納するタイプではなく、大規模データなどで使われる、バッチ毎に読み込む方法。
あまり最近のが無かったので、Qiitaで記事にした。
tensorflowは2.6.0を前提。

import glob
import math
import os
from typing import List, Tuple, Union

import albumentations as albu # 必要に応じて
import tensorflow as tf
import numpy as np # 必要に応じて


AUTOTUNE = tf.data.experimental.AUTOTUNE

tf.data

画像の読み込み

tf.dataを使う場合、まずは読み込む用の関数を用意する。
簡単に書けば、こういうものになる。

@tf.function(experimental_follow_type_hints=True)
def read_and_preprocess(img_path: tf.Tensor) -> tf.Tensor:
    # tf.ioでファイルを読み込んでから画像にする
    _read = tf.io.read_file(img_path)
    img = tf.image.decode_image(_read, channels=3, expand_animations=False)
    return img

高速化したいので、tf.functionでラッパーする。
(tf.dataを使うからargは絶対Tensorになるんだけど)、experimental_follow_type_hints=Trueを入れてtype hintsにtf.Tensorを入れておくと、たとえpythonのstringが入ってきても自動的にTensor型に変換してくれるので、再トレーシングとかが気にしなくても良くなる。これは今回のようなtf.dataに限らず、いろんな場面で使えるので、覚えておいて損はない。

tf.decode_imageは自動的にファイル形式を識別して読み込んでくれるので、decode_jpegとかdecode_pngとかよりも使いやすい。expand_animations=Falseでgifとか入ってきてもちゃんと3次元で返ってくれるようになる。ただし、jpegは注意点があるので、こちらを参照

前処理

やり方は主に2種類。KerasのPreprocessing Layerも入るかもしれない。
1. tensorflow公式が準備している前処理関数を使う(tf.image、tensorflow addonsのtfa.image )
2. 他モジュールやImageDataGeneratorを、tf.py_function経由で前処理をする(※)

バリエーションが多い順に並べると後者が強い。ただし、TPUやマルチGPU環境ではうまく動作しないらしい

1の例で、ランダムクロップと正規化を反映させたい場合、こんな感じ。

def preprocess(img: tf.Tensor) -> tf.Tensor:
    crop_size = tf.constant((128,128, img.shape[-1]))
    cropped = tf.image.random_crop(img, crop_size)
    normalized = tf.cast(cropped, tf.float32) / 255.
    return normalized

@tf.function(experimental_follow_type_hints=True)
def read_and_preprocess(img_path: tf.Tensor) -> tf.Tensor:
    # tf.ioでファイルを読み込んでから画像にする
    _read = tf.io.read_file(img_path)
    img = tf.image.decode_image(_read, channels=3, expand_animations=False)
    return preprocess(img)

他にも、いろいろな前処理があるので、公式ドキュメント画像の水増し方法をTensorFlowのコードから学ぶなどを参照。

2の場合、albumentationsを使ってみるとこんな感じ。transformsはだいぶ昔に使ってたものを引っ張ってきた。

# callの際に、適当な前処理をどれか1つ選んで行う。
transforms = albu.OneOf([
                albu.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=90),
                albu.GaussNoise(),
                albu.Equalize(),
                albu.ElasticTransform(),
                albu.GaussianBlur(),
                albu.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20),
                albu.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5),
                albu.HorizontalFlip()])

def preprocess(img) -> tf.Tensor:
    augmented = transforms(image=img.numpy())['image']
    return tf.convert_to_tensor(augmented, tf.float32)

@tf.function(experimental_follow_type_hints=True)
def read_and_preprocess(img_path: tf.Tensor) -> tf.Tensor:
    # tf.ioでファイルを読み込んでから画像にする
    _read = tf.io.read_file(img_path)
    img = tf.image.decode_image(_read, channels=3, expand_animations=False)
    img = tf.py_function(preprocess, [img], [tf.float32])
    return preprocess(img)

Datasetへ

ここは公式のチュートリアルの方が良いので、公式を見るべし
自分の場合、こういう形をよく作っている。

def generate_dataset(image_list: List[str], batch_size: int = 32, name: str = 'train'):
    """
    Parameters
    --------
    image_list: list (str, ...)
        読み込む画像のパスが含まれるリスト
    batch_size: int, default = 32
        バッチサイズ
    name: str, default = `train`
        データセットの名前。trainだと、シャッフルされる。
    """
    ds = tf.data.Dataset.from_tensor_slices(image_list)

    ds = ds.map(read_and_preprocess, num_parallel_calls=AUTOTUNE)
    if name == 'train':
        ds = ds.shuffle(2048)
    ds = ds.batch(batch_size)

    # check size
    size = ds.cardinality().numpy()
    if size > 0:
        print(f'{name} steps per epoch: {size}')

    ds = ds.prefetch(AUTOTUNE)
    return ds

例:

p = glob.glob('samples/*') # 適当なフォルダを指定
dataset = generate_dataset(image_list = p, batch_size = 2) # データセット生成

x = iter(dataset)
print(next(x))

出力:

train steps per epoch: 15

<tf.Tensor: shape=(2, 128, 128, 3), dtype=float32, numpy=
array([[[[0.9490196 , 0.93333334, 0.92941177],
         [0.9529412 , 0.9372549 , 0.93333334],
         [0.95686275, 0.9411765 , 0.9372549 ],
         ...,
         [0.9137255 , 0.91764706, 0.89411765],
         [0.90588236, 0.9098039 , 0.8862745 ],
         [0.9019608 , 0.90588236, 0.88235295]]]], dtype=float32)>

これでデータセットができた。あとはkerasのmodel.fitに入れるなり、train loopで使うなり。

TFRecord

こちらは、「ネットワーク経由で、画像を1枚1枚読み込むのは非効率なため、バッチ毎にデータを取得して読み込みたい」といったことをしたいときに使える(ストレージサーバからのデータ取得等)。分散学習(Tensorflow Federated)だったり、TPUだと良く使われるらしい。

データの書き出し

前処理とシリアライズ

まずは前処理とシリアライズ用の関数を準備。今回はSemantic Segmentationなどで使われることを想定して書く(画像とマスクがセットになる)。
以下は例で、前処理をカスタマイズしたい場合は、preprocess_imagepreprocess_maskを好きに変える。学習はしないので、型はnumpyでもok。よってalubumentationやcv2等が使える。

# 学習はしないので、tf.functionはつけなくて良い

def as_tensor(func):
    def function(*args, **kwargs) -> tf.Tensor:
        ret = func(*args, **kwargs)
        if not isinstance(ret, tf.Tensor):
            return tf.convert_to_tensor(ret)
        return ret
    return function

@as_tensor
def preprocess_image(img_path: Union[tf.Tensor, str], resize_shape: Tuple[int, int]=(256,256)):
    """
    画像の前処理関数
    """
    img = tf.io.read_file(img_path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, resize_shape)
    # ここに水増し用の関数を入れる
    # e.g.
    # img = transforms(image=img.numpy())['image']
    img = tf.convert_to_tensor(img, tf.float32)
    img /= 255. # 水増し用の関数で正規化した場合、ここは不要
    return img

@as_tensor
def preprocess_mask(img_path: Union[tf.Tensor, str], resize_shape: Tuple[int, int]=(256,256)):
    """
    マスク画像の前処理関数
    """
    img = tf.io.read_file(img_path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, resize_shape)
    img = tf.cast(img, tf.float32)
    img /= 255.
    return img

def serialize(base_path: tf.Tensor, mask_path: tf.Tensor):
    # 前処理とシリアライズ
    base_img = preprocess_image(base_path)
    base_img = tf.io.serialize_tensor(base_img)
    mask_img = preprocess_mask(mask_path)
    mask_img = tf.io.serialize_tensor(mask_img)

    bytes_base = tf.train.BytesList(value=[base_img.numpy()])
    bytes_mask = tf.train.BytesList(value=[mask_img.numpy()])

    features = tf.train.Features(
        feature={
            'base': tf.train.Feature(bytes_list = bytes_base),
            'mask': tf.train.Feature(bytes_list = bytes_mask)
            }
    )
    proto = tf.train.Example(features=features)
    return proto.SerializeToString()

特に前処理をしないなら、こういうのでok.

def simply_load(img_path: Union[tf.Tensor, str]) -> tf.Tensor:
    """
    マスク画像の前処理関数
    """
    img = tf.io.read_file(img_path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    return img

def serialize(base_path: tf.Tensor, mask_path: tf.Tensor):
    # 前処理とシリアライズ
    base_img = simply_load(base_path)
    base_img = tf.io.serialize_tensor(base_img)
    mask_img = simply_load(mask_path)
    mask_img = tf.io.serialize_tensor(mask_img)
    ...

書き出し

バッチサイズを指定して、(全ファイル数/バッチサイズ)個のファイルを作成する。

公式チュートリアルでも使われているtf.data.experimental.TFRecordWriterは2.6.0以降Deprecatedになったようで、tf.io.TFRecordWriterか、tf.data.experimental.save/tf.data.experimental.loadを使ってやってくれとのこと。後者は仕組みがよくわかっていないので、今回は前者で行う。

def split_path(filepathes: List[str], batch_size: int) -> List[List[str]]:
    total = math.ceil(len(filepathes)/batch_size)
    ret = [filepathes[i*batch_size:(i+1)*batch_size] for i in range(total)]
    if len(ret[-1]) == 0:
        ret = ret[:-1]
    return ret

def write_to_tfrecord(
    image_list: List[str],
    mask_list: List[str],
    batch_size: int = 32,
    output_path: str = 'tfr_outputs',
    output_path_exist_ok: bool=False):
    """
    Parameters
    --------
    image_list: list (str, ...)
        読み込む画像のパスが含まれるリスト
    mask_list: list (str, ...)
        読み込む画像(マスク)のパスが含まれるリスト。image_listと同じ長さでなければならない。
    batch_size: int, default = 32
        バッチサイズ
    output_path: str, default = `tfr_outputs/`
        tfrecordファイルの保存先
    output_path_exist_ok: bool, default = `False`
        output_pathの重複作成を許可するかどうか。Falseかつすでにフォルダが存在する場合。エラーが出る。
    """
    if len(image_list) != len(mask_list):
        raise ValueError('The sizes of image_list and mask_list do not match({} vs {}).'.format(len(image_list), len(mask_list)))

    if len(image_list) == 0:
        raise ValueError('Empty dataset.')

    os.makedirs(output_path, exist_ok=output_path_exist_ok)

    filepathes = list(zip(image_list, mask_list))

    split_fs = split_path(filepathes, batch_size)
    size = len(split_fs)
    adjust_size = len(list(str(size)))
    for i, fs in enumerate(split_fs):
        output = os.path.join(output_path, 'batch_{}.tfrecord'.format(str(i+1).rjust(adjust_size, '0')))
        print('\rWriting {}/{} to {}...'.format(str(i+1).rjust(adjust_size), size, output), end='')
        with tf.io.TFRecordWriter(output) as writer:
            for targets in fs:
                writer.write(serialize(*targets))
    print('\nDone.')

例:

write_to_tfrecord(image_list = p, mask_list = p, batch_size = 2)

出力:

Writing 15/15 to tfr_outputs/batch_15.tfrecord...
Done.

データの読み込み

tf.recordのファイルを読み込んでdeserializeする。
ここで前処理したいなら、deserializeの最後で前処理する関数などを入れる。

def deserialize(proto):
    parsed = tf.io.parse_example(
        proto,
        {
            'base': tf.io.FixedLenFeature([], tf.string),
            'mask': tf.io.FixedLenFeature([], tf.string)
        })
    base_img = tf.io.parse_tensor(parsed['base'], out_type=tf.float32)
    mask_img = tf.io.parse_tensor(parsed['mask'], out_type=tf.float32)
    # 前処理したいならここにその操作を挿入
    return base_img, mask_img

def generate_dataset_from_tfrecord(path: str, batch_size: int = 32, name: str = 'train'):
    """
    Parameters
    --------
    path: str
        読み込むtfrecordがあるフォルダ
    batch_size: int, default = 32
        バッチサイズ
    name: str, default = `train`
        データセットの名前。trainだと、シャッフルされる。
    """
    file_list = [p for p in glob.glob(path + '/*') if os.path.isfile(p) and os.path.splitext(p)[1] == '.tfrecord']
    if len(file_list) == 0:
        raise ValueError('Empty Dataset.')
    else:
        print(f'{name} steps per epoch: {len(file_list)}')

    ds = tf.data.TFRecordDataset(file_list)
    ds = ds.map(deserialize, num_parallel_calls=AUTOTUNE).batch(batch_size)
    if name == 'train':
        ds = ds.shuffle(2048)
    ds = ds.prefetch(AUTOTUNE)
    return ds

例:

ds = generate_dataset_from_tfrecord('tfr_outputs', batch_size=2)
x = next(iter(ds))
print(len(x))
print(x[0])

出力:

train steps per epoch: 15
2
tf.Tensor(
[[[[0.775337   0.802788   0.8263174 ]
   [0.75114125 0.7785922  0.80212164]
   [0.7315334  0.7589844  0.7825138 ]
   ...
   [0.84117645 0.8372549  0.8215686 ]
   [0.84117645 0.8372549  0.8215686 ]
   [0.84432447 0.8404029  0.8247166 ]]]], shape=(2, 256, 256, 3), dtype=float32)

ちゃんと型が復元されていることが分かる。

6
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
0