はじめに
tensorflow.data.Datasetを使っていて、preprocess内でnumpyのFFTやopenCVを使ったらエラーが出て困った人向けです。いや自分自身が困ったので、そのTipです。
時間がない人向け
tf.py_functionを使用すると出来る。
細かい説明
tf.data.Datasetの簡単な動き
こちらはtensorflowの公式のドキュメントで書いてある通りに、datasetのloaderを作ります。pytorchでいえばdataloaderですね。
def preprocess_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [192, 192])
return image
def load_and_preprocess_image(path):
image = tf.io.read_file(path)
return preprocess_image(image)
path_ds = tf.data.Dataset.from_tensor_slices(path_list)
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
label_ds = tf.data.Dataset.from_tensor_slices(label)
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
やっていることは簡単で、pathを渡すと、preprocessを終えたデータが読み込まれるというよくあるdataloaderを作っています。事前に画像のPathのリストと、紐づくラベルのリストを作っておけば簡単です。
これを実行してみると、下記のようなサンプルとなります。ちなみに画像はCifar10です。
fig = plt.figure(figsize=(10, 10))
for i, (im, l) in enumerate(image_label_ds.take(9)):
fig.add_subplot(3, 3, i + 1)
plt.imshow(im.numpy().astype(np.uint8))
plt.title(annot[l.numpy()]) # annotはラベル名のdictionary
plt.tight_layout()
凝ったPreProcessをしたい
このpreprocess内で少し凝った処理を走らせましょう。今回は例として、openCVを使って、色空間を変換(RGB->YCrCb)しています。別にNumpyのFFT関数でもいいし、まぁtensorflow純正以外の関数を使うなら何でも良いです。これで実行してみます。
def preprocess_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [192, 192])
image = cv2.cvtColor(image.numpy(), cv2.COLOR_RGB2YCrCb)
image = tf.convert_to_tensor(image)
return image
なんかものすごいエラーが出ます。エラー内容が見たい人は左の三角をクリック。
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
~/workspace/project/cifar10/load_cifar10.py in
58 AUTOTUNE = tf.data.experimental.AUTOTUNE
59 path_ds = tf.data.Dataset.from_tensor_slices(path_list)
----> 60 image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
61
62 label_ds = tf.data.Dataset.from_tensor_slices(label)
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls, deterministic)
1700 num_parallel_calls,
1701 deterministic,
-> 1702 preserve_cardinality=True)
1703
1704 def flat_map(self, map_func):
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, num_parallel_calls, deterministic, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
4082 self._transformation_name(),
4083 dataset=input_dataset,
-> 4084 use_legacy_function=use_legacy_function)
4085 if deterministic is None:
4086 self._deterministic = "default"
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
3369 with tracking.resource_tracker_scope(resource_tracker):
3370 # TODO(b/141462134): Switch to using garbage collection.
-> 3371 self._function = wrapper_fn.get_concrete_function()
3372 if add_to_graph:
3373 self._function.add_to_graph(ops.get_default_graph())
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self, *args, **kwargs)
2937 """
2938 graph_function = self._get_concrete_function_garbage_collected(
-> 2939 *args, **kwargs)
2940 graph_function._garbage_collector.release() # pylint: disable=protected-access
2941 return graph_function
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
2904 args, kwargs = None, None
2905 with self._lock:
-> 2906 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
2907 seen_names = set()
2908 captured = object_identity.ObjectIdentitySet(
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3211
3212 self._function_cache.missed.add(call_context_key)
-> 3213 graph_function = self._create_graph_function(args, kwargs)
3214 self._function_cache.primary[cache_key] = graph_function
3215 return graph_function, args, kwargs
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3073 arg_names=arg_names,
3074 override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075 capture_by_value=self._capture_by_value),
3076 self._function_attributes,
3077 function_spec=self.function_spec,
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
984 _, original_func = tf_decorator.unwrap(python_func)
985
--> 986 func_outputs = python_func(*func_args, **func_kwargs)
987
988 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_fn(*args)
3362 attributes=defun_kwargs)
3363 def wrapper_fn(*args): # pylint: disable=missing-docstring
-> 3364 ret = _wrapper_helper(*args)
3365 ret = structure.to_tensor_list(self._output_structure, ret)
3366 return [ops.convert_to_tensor(t) for t in ret]
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in _wrapper_helper(*args)
3297 nested_args = (nested_args,)
3298
-> 3299 ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
3300 # If `func` returns a list of tensors, `nest.flatten()` and
3301 # `ops.convert_to_tensor()` would conspire to attempt to stack
~/.pyenv/versions/py37_emopy/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
256 except Exception as e: # pylint:disable=broad-except
257 if hasattr(e, 'ag_error_metadata'):
--> 258 raise e.ag_error_metadata.to_exception(e)
259 else:
260 raise
AttributeError: in user code:
<ipython-input-87-550493cbb714>:12 load_and_preprocess_image *
return preprocess_image(image)
<ipython-input-93-6034e2f0f0fd>:4 preprocess_image *
image = cv2.cvtColor(image.numpy(), cv2.COLOR_RGB2YCrCb)
AttributeError: 'Tensor' object has no attribute 'numpy'
ちなみにこの色変換の関数そのものには異常はありません。TensorflowのDatasetのようなパイプラインを介さずに利用すると、ちゃんと画像が表示できます。
image = tf.io.read_file(image_path)
image = preprocess_image(image)
fig = plt.figure(figsize=(8, 4))
fig.add_subplot(1, 2, 1)
plt.imshow(image.numpy().astype(np.uint8))
plt.title("YCC")
fig.add_subplot(1, 2, 2)
# 表示用にもう一度YCC→RGBに戻している
image = cv2.cvtColor(image.numpy(), cv2.COLOR_YCrCb2RGB)
plt.imshow(image.astype(np.uint8))
plt.title("RGB")
tf.py-function降臨
こういう事態を避ける方法があって、tf.py_functionを使用します。下記の例のように、tensorflow純正の関数ではない変換を別の関数にしてやり、それをラッピングする形で使用します。簡単ですね。
def trans_color(image):
image = cv2.cvtColor(image.numpy(), cv2.COLOR_RGB2YCrCb)
return tf.convert_to_tensor(image)
def preprocess_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [192, 192])
image = tf.py_function(trans_color, [image], tf.float32)
return image
これで実際に動かしてみると、下記画像が得られました。
右上の画像と前項での画像が同じになっているのでうまく動いていそうです。
まとめ
tensorflowの前処理で、自身がよく知っているnumpyのFFTを使いたいとか、画像処理の定番のopenCVを使いたいとか、そういうことのTipsでした。py_functionはもっと奥深い機能だと思いますが、こんなふうにも使えるんですね。