LoginSignup
0
0

More than 3 years have passed since last update.

データセットのローダーの作成

Posted at

はじめに

データセット

class DataSet():
    """データセットの管理."""

    def __init__(self, images, labels):
        self._num_examples = images.shape[0]
        images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0
def dense_to_one_hot(labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors."""
    num_labels = labels_dense.shape[0]
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot

pickle データの読み込み

  • 上記のデータセットに pickle の画像、ラベルを読み込みます。

データの読み込み

  • 設定ファイルを元に、オリジナルのデータ、もしくは水増しデータを読み込みます。
def load_data(one_hot=False, validation_size=0):
    """データセットを config.py に従い読み込み."""

    train_num = AUGMENT_NUM if USE_AUGMENT else 0
    datasets_file = os.path.join(DATASETS_PATH, ','.join(CLASSES), '{}x{}-{}.pickle'.format(IMG_ROWS, IMG_COLS, train_num))

    with open(datasets_file, 'rb') as fin:
        (train_images, train_labels), (test_images, test_labels) = pickle.load(fin)

ラベルの one hot 変換

  • 今回は、読み込んだラベルデータは、one_hot に変換して後工程に繋げています。
  • 例えば、晴れ: 1、曇り: 2、雨: 3 の様にラベルが付いているとすると、晴れ: (1, 0, 0)、曇り: (0, 1, 0)、雨: (0, 0, 1) の様な形への変換です。
  • 詳しくは、機械学習 one hot で検索してください。
    if one_hot:
        num_classes = len(numpy.unique(train_labels))
        train_labels = dense_to_one_hot(train_labels, num_classes)
        test_labels = dense_to_one_hot(test_labels, num_classes)

画像とラベルのデータセット クラス化

  • データセットを読み込むたびに、位置をシャッフルしています。
  • バリデーションが必要な場合もあるので、学習データから切り出せる様にしています。
  • 最後に、学習、バリデーション、テストのデータセットクラス化をしています。
    perm = numpy.arange(train_images.shape[0])
    numpy.random.shuffle(perm)
    train_images = train_images[perm]
    train_labels = train_labels[perm]

    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    train = DataSet(train_images, train_labels)
    validation = DataSet(validation_images, validation_labels)
    test = DataSet(test_images, test_labels)

    return Datasets(train=train, validation=validation, test=test)

おわりに

  • データセットのローダーを作成しました。元データの、pickle データを読み込む箇所に改造を行なっています。
  • ただ、最近は、この辺りを隠蔽してプログラム出来るので、わざわざ実装する必要は少ないです。最初で最後だと思う。
  • 次回は、学習モデルを作成したいと思います。
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0