1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

[TensorFlow]入力パイプラインでできることまとめ

Posted at

はじめに

機械学習モデルを構築する際、データは前処理や拡張等の何らかの処理を加えられることがある。Tensorflowにはそういった処理プロセスを効率的に処理するために入力パイプラインが存在する。今回は入力パイプラインでできることの一例をまとめていきたいと思う

入力パイプラインとは?

機械学習モデルに入力されるデータを効率的に処理し、供給するための一連のステップである。入力パイプラインには、データの読み込み、前処理、拡張、バッチ処理などが含まれる。

入力パイプライン

Tensorflowで入力パイプラインを構築するためには、tf.data APIを使用する。
基本的な仕組みは、まずはデータを基にDatasetオブジェクトを作成し、Datasetオブジェクトのメソッドを用いて新たなDatasetオブジェクトを生成する。例えば、メソッドにはDataset.map
Dataset.batchがある。

それでは、入力パイプラインでできることをまとめていく。

入力データの読み取り

データを入力し、Datasetオブジェクトを生成する。様々なデータからDatasetオブジェクトを生成することができ、例えば以下のようなものを扱うことができる。

  • list
  • Numpy配列
  • TFRecord データ
  • テキストデータ
  • CSV
  • ファイルセット

コードは以下の通り。

# list
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])

# Numpy
dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3]))

# 説明変数と目的変数をセットにしてDatasetを作成することも可能
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# TFRecord
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])

# テキストデータ
directory_url = 'hogehoge/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']
file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]
dataset = tf.data.TextLineDataset(file_paths)

# CSV
csv_file = tf.keras.utils.get_file("train.csv", "hogehoge/train.csv")
df = pd.read_csv(csv_file)
csv_slices = tf.data.Dataset.from_tensor_slices(dict(df))

# CSV(データがメモリに収まらない場合)
csv_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="fugafuga")

# ファイルセット
flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)
list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

バッチ処理

ニューラルネットワークに基づくモデルを構築する際、ミニバッチ学習がしばしば行われる。そんな時に使えるのがbatch()メソッドだ。
batchメソッドを使用することで、n個の連続するデータセットを1つの要素にまとめることができる。エポックごとに32個のデータでミニバッチを行いたいのであれば、dataset.batch(32)とすることでミニバッチ学習に使用できるデータセットを作成できる。

exp_val = tf.data.Dataset.range(100)
label = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((exp_val, label))
batched_dataset = dataset.batch(5)
for batch in batched_dataset.take(5):
    print([arr.numpy() for arr in batch])
[array([0, 1, 2, 3, 4]), array([ 0, -1, -2, -3, -4])]
[array([5, 6, 7, 8, 9]), array([-5, -6, -7, -8, -9])]
[array([10, 11, 12, 13, 14]), array([-10, -11, -12, -13, -14])]
[array([15, 16, 17, 18, 19]), array([-15, -16, -17, -18, -19])]
[array([20, 21, 22, 23, 24]), array([-20, -21, -22, -23, -24])]

入力データのシャッフル

入力データをランダムシャッフルしたい場合は、Dataset.shuffle()変換が使用できる。
Dataset.shuffle()は引数にbuffer_sizeが必須である。ランダム抽出の処理手順は下記と理解した。

1.初期バッファ充填: buffer_size個の要素をdatasetの初めから取得
2.ランダム抽出: バッファにプールされた要素からランダムに要素を抽出
3. 補充: 抽出した分だけdatasetから値を取得(初期バッファ充填で選択されていない、一番初めの要素が取得される)

exp_val = tf.data.Dataset.range(100)
label = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((exp_val, label))
dataset = dataset.shuffle(buffer_size=10)
dataset = dataset.batch(5)
for batch in dataset.take(5):
  print([arr.numpy() for arr in batch])

データセットに関数を適用

Dataset.map(f)によって、指定された関数fをデータセットの各要素に適用できる。

def f(x, y):
    return x**2, y

exp_val = tf.data.Dataset.range(100)
label = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((exp_val, label))
dataset = dataset.map(f)
dataset = dataset.batch(5)
for batch in dataset.take(5):
  print([arr.numpy() for arr in batch])
[array([ 0,  1,  4,  9, 16]), array([ 0, -1, -2, -3, -4])]
[array([25, 36, 49, 64, 81]), array([-5, -6, -7, -8, -9])]
[array([100, 121, 144, 169, 196]), array([-10, -11, -12, -13, -14])]
[array([225, 256, 289, 324, 361]), array([-15, -16, -17, -18, -19])]
[array([400, 441, 484, 529, 576]), array([-20, -21, -22, -23, -24])]

不均衡データに対するリサンプリング

データセットが不均衡な場合、データセットをリサンプリングすることがある。事前に正例と負例のデータセットに分割する必要はあるが、下記でリサンプルすることが可能である。今回は正例/負例を50/50 の等しい確率で得られるようにしている。

balanced_ds = tf.data.Dataset.sample_from_datasets(
    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)

事前に正例と負例のデータセットに分割するのが面倒であれば、tf.data.Dataset.rejection_resampleによる棄却リサンプリングもある。棄却リサンプリングする場合は、class_func引数が必要で、どのカラムがラベルに属するのか判定する関数が必要である。

def class_func(features, label):
  return label

resample_ds = (
    creditcard_ds
    .unbatch()
    .rejection_resample(class_func, target_dist=[0.5,0.5],
                        initial_dist=[0.9, 0.1]) # リサンプリングする前の比率
    .batch(10))

おわりに

Tensorflowのtf.data APIを用いた入力パイプラインについて整理した。機械学習モデルを学習する際に行う、様々な処理がAPIとして用意されている。
詳しい説明はTensorflowのドキュメントにまとまっているため、不明点あれば参考にしてみてほしい。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?