LoginSignup
0
1

More than 3 years have passed since last update.

MNISTデータをランダムにサンプリングしてデータセットを作成する

Last updated at Posted at 2020-11-19

#概要
 MNISTデータセット全てではなく、MNISTの一部を使って学習をさせる必要が出てきました。そこでMNISTのTrainingデータ60000枚からランダムにn枚抽出して、クラス毎にフォルダ分けをして画像を保存するプログラムを作成しました。

#実行環境
Google Colaboratory
PyTorch 1.6.0

#実装
##MNISTを画像形式で保存
Trainデータセットからランダムに抽出できるようにMNISTデータセットをダウンロードし、画像形式で保存します。
こちらのサイトを参考にさせていただきました。
PyTorchでImageFolderを使ってみる

まずは必要なモジュールのimportから

import os
from PIL import Image
from torchvision.datasets import MNIST
import shutil
import glob
from pprint import pprint
import random
from pathlib import Path
from tqdm import tqdm

必要なモジュールがない場合は適宜pipやcondaでインストールしてください。

続いてMNISTをダウンロードします。

mnist_data = MNIST(root='./', train=True, transform=None, download=True)

mnistをダウンロードした際にUserWarningが出るかもしれませんが、今回はダウンロードしたmnistを使って学習を行うわけでは無いので気にしないでください。

ダウンロードしたMNISTのバイナリファイルからPNG形式でMNIST画像を保存します。

def makeMnistPng(image_dsets):
    for idx in tqdm(range(10)):
        print("Making image file for index {}".format(idx))
        num_img = 0
        dir_path = './mnist_all/'
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        for image, label in image_dsets:
           if label == idx:
                filename = dir_path +'/mnist_'+ str(idx) + '-' + str(num_img) + '.png'
                if not os.path.exists(filename):
                    image.save(filename)
                num_img += 1
    print('Success to make MNIST PNG image files. index={}'.format(idx))

関数を実行します。

makeMnistPng(mnist_data)

 これでmnist_all下にmnistの画像60000万枚全てが保存されました。クラス毎に画像を保存したい場合は以下のようにしてください。

def makeMnistPng(image_dsets):
    for idx in tqdm(range(10)):
        print("Making image file for index {}".format(idx))
        num_img = 0
        dir_path = './MNIST_PNG/' + str(idx)
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        for image, label in image_dsets:
            if label == idx:
                filename = dir_path +'/' + 'mnist_'+ str(idx) + '_' + str(num_img) + '.png'
                if not os.path.exists(filename):
                    image.save(filename)
                num_img += 1
    print('Success to make MNIST PNG image files. index={}'.format(idx))

##ディレクトリ内のファイルからランダムにサンプリングする
 mnistの全てのデータを1つのディレクトリに落とし込むことができたのでそこからランダムにn枚の画像をサンプリングし、別ディレクトリにコピーしていきます。
 参考にさせていただいた(ほぼそのまま使用した)記事はこちら

###クラスの定義


class FileControler(object):
    def get_file_path(self, input_dir, pattern):
        #ファイルパスの取得
        #ディレクトリを指定しパスオブジェクトを生成
        path_obj = Path(input_dir)
        #glob形式でファイルをマッチ
        files_path = path_obj.glob(pattern)
        #文字列として扱うためposix変換
        files_path_posix = [file_path.as_posix() for file_path in files_path]
        return files_path_posix
    
    def random_sampling(self, files_path, sample_num, output_dir, fix_seed=True) -> None:
        #ランダムサンプリング
        #毎回同じファイルをサンプリングするにはSeedを固定する
        if fix_seed is True:
            random.seed(0)
        #ファイル群のパスとサンプル数を指定
        files_path_sampled = random.sample(files_path, sample_num)
        #出力先ディレクトリがなければ作成
        os.makedirs(output_dir, exist_ok=True)
        #コピー
        for file_path in files_path_sampled:
            shutil.copy(file_path, output_dir)

###インスタンス作成

file_controler =FileControler()

###ディレクトリの設定
サンプリング元のディレクトリとサンプリングしたファイルをコピーするディレクトリを設定します。

all_file_dir = './mnist_all/'
sampled_dir = './mnist_sampled/'

###全てのファイルのパスを取得

pattern = '*.png'
files_path = file_controler.get_file_path(all_file_dir, pattern)

print(len(files_path))
# 60000

###n枚サンプリング

sample_num = 100
file_controler.random_sampling(files_path, sample_num, sampled_dir)

sampled_files_path = file_controler.get_file_path(sampled_dir, pattern)
print(len(sampled_files_path))
# 100

 これでmnist60000枚の中からランダムにn枚(今回は100枚)サンプリングされました。

###クラス分け
 機械学習のデータセットとして使用できるよう、サンプリングされた画像たちをクラス毎に分けていきます。

まずサンプリングしたディレクトリ内のファイル名をリスト形式で全て取得します。

files = glob.glob("./mnist_sampled/*")

ファイル名のリストに対してin演算子を用いて部分文字列の判定を行いクラス毎にフォルダわけしていきます。

for i in range(10):
    os.makedirs(sampled_dir+str(i), exist_ok=True)
    for x in files:
        if '_' + str(i) in x:
            shutil.move(x, sampled_dir + str(i))

サンプリングしたディレクトリはこのようなディレクトリ構成になります。

./mnist_sampled
├── 0
├── 1
├── 2
├── 3
├── 4
├── 5
├── 6
├── 7
├── 8
└── 9

これでmnistの画像をランダムにサンプリングし、それらをクラス分けをしてデータセットを作成することができました。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1