2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

MVTecAD データセットローダー TensorFlow

Last updated at Posted at 2022-01-09

# はじめに

MVTec AD データセットを TensorFlow の `tf.data.Dataset` として読み込むローダープログラムです。

# プログラム①: MVtecADLoader クラス

Data_Loader.py
# MVTec AD Data Loader

# -----------------------------------------------
# Import
# -----------------------------------------------
# 01. Built-in
import os

# 02. Third-party modules
import cv2
import numpy as np
import tensorflow as tf




# -----------------------------------------------
# Class
# -----------------------------------------------

class MVtecADLoader():
    """Loader for one category of the MVTec AD anomaly-detection dataset.

    Calling :meth:`load` populates ``self.train`` and ``self.test`` with
    ``tf.data.Dataset`` objects that yield ``(image, mask, label)`` triples:

    * ``image`` -- float32 RGB image of shape ``(img_size, img_size, 3)``.
    * ``mask``  -- int32 mask of shape ``(img_size, img_size)``; 1 marks a
      defective pixel, 0 everything else.
    * ``label`` -- int32 scalar, 1 for a good sample and 0 for an anomaly.

    The training split contains only good images, each included ``repeat``
    times with an independent random rotation. The test split contains the
    good and defective images together with their ground-truth masks.
    """

    # File extensions treated as images. Anything else found in a directory
    # (e.g. ``desktop.ini`` created by Google Drive on Windows) is skipped.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, img_size=224, base_path=None):
        """Configure the loader.

        Args:
            img_size: Side length (pixels) that every image and mask is
                resized to.
            base_path: Root directory of the extracted MVTec AD dataset. When
                omitted, the author's original Google Drive location is used.
        """
        # Datasets path
        if base_path is None:
            path_notebook = 'C:/Users/user_name/'
            path_nokori = 'Google ドライブ/datasets/mvtec_anomaly_detection'
            base_path = os.path.join(path_notebook, path_nokori)
        self.base_path = base_path

        # Image resolution
        self.img_size = img_size

        # Sample counts and datasets; all of these are filled in by load().
        self.num_train = None
        self.num_test = None
        self.train = None
        self.test = None

        # Test-set sub-directories per category. 'good' holds defect-free
        # images (label 1); every other entry is a defect type (label 0).
        self.category = {'bottle': ['good', 'broken_large', 'broken_small', 'contamination'],
                         'cable': ['good', 'bent_wire', 'cable_swap', 'combined', 'cut_inner_insulation',
                                   'cut_outer_insulation',
                                   'missing_cable', 'missing_wire', 'poke_insulation'],
                         'capsule': ['good', 'crack', 'faulty_imprint', 'poke', 'scratch', 'squeeze'],
                         'carpet': ['good', 'color', 'cut', 'hole', 'metal_contamination', 'thread'],
                         'grid': ['good', 'bent', 'broken', 'glue', 'metal_contamination', 'thread'],
                         'hazelnut': ['good', 'crack', 'cut', 'hole', 'print'],
                         'leather': ['good', 'color', 'cut', 'fold', 'glue', 'poke'],
                         'metal_nut': ['good', 'bent', 'color', 'flip', 'scratch'],
                         'pill': ['good', 'color', 'combined', 'contamination', 'crack', 'faulty_imprint', 'pill_type',
                                  'scratch'],
                         'screw': ['good', 'manipulated_front', 'scratch_head', 'scratch_neck', 'thread_side',
                                   'thread_top'],
                         'tile': ['good', 'crack', 'glue_strip', 'gray_stroke', 'oil', 'rough'],
                         'toothbrush': ['good', 'defective'],
                         'transistor': ['good', 'bent_lead', 'cut_lead', 'damaged_case', 'misplaced'],
                         # 'good' appears only once here; listing it twice loaded
                         # every good test image twice.
                         'wood': ['good', 'color', 'combined', 'hole', 'liquid', 'scratch'],
                         'zipper': ['good', 'broken_teeth', 'combined', 'fabric_border', 'fabric_interior', 'rough',
                                    'split_teeth', 'squeezed_teeth']}

    def load(self, category, repeat=4, max_rot=10):
        """Load the train and test splits of *category*.

        Sets ``self.train``, ``self.test``, ``self.num_train`` and
        ``self.num_test``.

        Args:
            category: Category name; must be a key of ``self.category``.
            repeat: Number of times each training image is included. Each
                copy gets an independent random rotation.
            max_rot: Maximum rotation in degrees for training augmentation;
                0 disables rotation.

        Raises:
            KeyError: If *category* is not a known MVTec AD category.
            ValueError: If the training directory contains no images, or an
                image file cannot be decoded.
        """
        # Validate up front so an unknown category fails before the (slow)
        # training-set load rather than after it.
        if category not in self.category:
            raise KeyError('Unknown MVTec AD category: {!r}'.format(category))

        category_path = os.path.join(self.base_path, category)

        # 01. Train dataset (OK only)
        x, y, z = self._load_train(category_path, repeat, max_rot)
        self.num_train = len(x)
        self.train = self._to_dataset(x, y, z)

        # 02. Test dataset (OK and NG)
        x, y, z = self._load_test(category_path, self.category[category])
        self.num_test = len(x)
        self.test = self._to_dataset(x, y, z)

    def _load_train(self, category_path, repeat, max_rot):
        """Return (images, masks, labels) lists for the good training images."""
        path = os.path.join(category_path, 'train', 'good')
        # Decode every file once. Each repetition below re-rotates the cached
        # image instead of reading it from disk again.
        images = [self._read_image(img_path=os.path.join(path, file))
                  for file in self._list_images(path)]
        if not images:
            raise ValueError('No training images found in {}'.format(path))

        zero_mask = np.zeros((self.img_size, self.img_size), dtype=np.int32)
        x, y, z = [], [], []
        for _ in range(repeat):
            for img in images:
                if max_rot != 0:
                    # Images are (H, W, C). Keras' defaults assume channels-first
                    # (row_axis=1, col_axis=2, channel_axis=0), so the axes must
                    # be given explicitly.
                    img = tf.keras.preprocessing.image.random_rotation(
                        img, max_rot, row_axis=0, col_axis=1, channel_axis=2)
                x.append(img)
                y.append(zero_mask)  # good images have no defective pixels
                z.append(1)          # label 1 = good
        return x, y, z

    def _load_test(self, category_path, labels):
        """Return (images, masks, labels) lists for every test image of a category."""
        x, y, z = [], [], []
        for label in labels:
            path = os.path.join(category_path, 'test', label)
            for file in self._list_images(path):
                img = self._read_image(img_path=os.path.join(path, file))

                if label == 'good':
                    mask = np.zeros((self.img_size, self.img_size), dtype=np.int32)
                    idx_z = 1  # label = 1 for good
                else:
                    mask_file = '{}_mask.png'.format(os.path.splitext(file)[0])
                    mask_path = os.path.join(category_path, 'ground_truth', label, mask_file)
                    # Nearest-neighbour resizing keeps the mask strictly binary.
                    # Bilinear resizing followed by an int cast would shrink
                    # the defect regions.
                    mask = self._read_image(img_path=mask_path,
                                            flags=cv2.IMREAD_GRAYSCALE,
                                            interpolation=cv2.INTER_NEAREST)
                    mask = (mask > 127).astype(np.int32)
                    idx_z = 0  # label = 0 for NG

                # Append once per file (not once per label directory) and pair
                # each image with its own mask.
                x.append(img)
                y.append(mask)
                z.append(idx_z)
        return x, y, z

    def _list_images(self, path):
        """Return the sorted image file names in *path*; non-image files are skipped."""
        return sorted(file for file in os.listdir(path)
                      if file.lower().endswith(self._IMAGE_EXTENSIONS))

    @staticmethod
    def _to_dataset(x, y, z):
        """Zip image, mask and label lists into one ``tf.data.Dataset``."""
        # Cast explicitly in NumPy so the dataset dtypes do not depend on
        # tf.convert_to_tensor's implicit casting rules.
        x = tf.data.Dataset.from_tensor_slices(np.asarray(x, dtype=np.float32))
        y = tf.data.Dataset.from_tensor_slices(np.asarray(y, dtype=np.int32))
        z = tf.data.Dataset.from_tensor_slices(np.asarray(z, dtype=np.int32))
        return tf.data.Dataset.zip((x, y, z))

    def _read_image(self, img_path, flags=cv2.IMREAD_COLOR, interpolation=cv2.INTER_LINEAR):
        """Read an image file and resize it to ``(img_size, img_size)``.

        The file is read with ``np.fromfile`` + ``cv2.imdecode`` rather than
        ``cv2.imread`` so that paths containing non-ASCII characters (such as
        'Google ドライブ') also work on Windows.

        Args:
            img_path: Path of the image file.
            flags: ``cv2.imdecode`` flags, e.g. ``cv2.IMREAD_GRAYSCALE`` for masks.
            interpolation: ``cv2.resize`` interpolation; use
                ``cv2.INTER_NEAREST`` for label masks.

        Returns:
            An RGB ``(img_size, img_size, 3)`` array for colour images, or an
            ``(img_size, img_size)`` array for grayscale images.

        Raises:
            ValueError: If the file is empty or cannot be decoded as an image.
        """
        buf = np.fromfile(img_path, dtype=np.uint8)
        img = cv2.imdecode(buf, flags) if buf.size else None
        if img is None:
            raise ValueError('Could not decode image: {}'.format(img_path))
        # OpenCV decodes colour images as BGR. Single-channel (grayscale)
        # images must not go through BGR2RGB, because cvtColor rejects them.
        if img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return cv2.resize(img, dsize=(self.img_size, self.img_size),
                          interpolation=interpolation)

# プログラム②: テストプログラム

test.py

# -----------------------------------------------
# Import
# -----------------------------------------------
# 01. Built-in
import os

# 02. Third-party modules
import cv2
import numpy as np
import tensorflow as tf

# 03. My Module
from Data_Loader import MVtecADLoader


# Variables
img_size = 224
batch_size = 16

# Instance. MVtecADLoader.__init__ has no ``com`` parameter, so only the
# image size is passed.
loader = MVtecADLoader(img_size=img_size)

#------------------------------------------------------
# Data Load
#------------------------------------------------------
loader.load(category='bottle')

# Train data and test data.
# Shuffle individual samples *before* batching. Shuffling after .batch()
# only reorders whole batches, so every batch would keep the same samples.
train_set = (loader.train
             .shuffle(buffer_size=loader.num_train, reshuffle_each_iteration=True)
             .batch(batch_size=batch_size, drop_remainder=True))
test_set = loader.test.batch(batch_size=1, drop_remainder=False)

print('train set')
for x, y, z in train_set:
    print(x.shape)  # images
    print(y.shape)  # masks
    print(z.shape)  # labels

print('test set')
for x, y, z in test_set:
    print(x.shape)  # images
    print(y.shape)  # masks
    print(z.shape)  # labels
実行結果
train set
(16, 224, 224, 3) #x 
(16, 224, 224) #y
(16,) #z
...

test set
(1, 224, 224, 3)
(1, 224, 224)
(1,)
...

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?