LoginSignup
0
0

More than 3 years have passed since last update.

中身を理解!Chainerにおけるデータの準備・設定方法3 〜 chainer_datasets.get_mnist()と同等機能の関数を作成してみる 〜

Posted at

Chainerにおけるデータの準備・設定方法について書いています。

MNISTデータを学習するChainerのプログラムの説明はよくあるのですが、データの準備・設定方法については「train, test = datasets.get_mnist(ndim = 3)」の1行で終わっていてブラックボックスとなっていてよく分からないため、この記事を書いています。MNISTデータを学習するChainerのプログラム自身は理解していることを前提としています。以下の内容からなる全3回の3回目になります。

中身を理解!Chainerにおけるデータの準備・設定方法1
〜 MNISTデータをコピーしてみる(バイナリデータの操作) 〜
中身を理解!Chainerにおけるデータの準備・設定方法2
〜 JPEGファイルからMNISTと同じ仕様の画像データを作成してみる 〜
中身を理解!Chainerにおけるデータの準備・設定方法3
〜 chainer_datasets.get_mnist()と同等機能の関数を作成してみる 〜

第3回では、chainer_datasets.get_mnist(ndim = 3)と同等機能の関数で、イメージデータファイルとラベルデータファイルからタプルデータセットを作成するプログラムについて説明します。

ただし、chainer_datasets.get_mnist()の全ての仕様を実装している訳ではありません。ndim = 3で、残りのパラメータは全てデフォルトのグレースケールの画像データのみを対象としています。基本的にMNISTを理解するためのものだと思って下さい。実際の使い方としては、例えば、「train, test = datasets.get_mnist(dim = 3)」の代わりに、以下の様に使える関数(create_tuple_dataset_from_image())のプログラムについて説明します。

train = create_tuple_dataset_from_image('tmp/', 'train-images-idx3-ubyte', 'tmp/', 'train-labels-idx1-ubyte')
test =  create_tuple_dataset_from_image('tmp/', 't10k-images-idx3-ubyte', 'tmp/', 't10k-labels-idx1-ubyte')

また、第2回で作成したプログラム(create_image_dataset.py)が作成したイメージデータファイル(outimagefile)、ラベルデータファイル(outlabelfile)をテストデータに使うのであれば、以下の様になります。なお、MNISTデータセットの手書き数字は黒地に白文字です。訓練データにはMNISTを使い、テストデータのみ自分で手書き数字を用意する場合は黒地に白文字の手書き数字を用意すべきでしょう。

train = create_tuple_dataset_from_image('tmp/', 'train-images-idx3-ubyte', 'tmp/', 'train-labels-idx1-ubyte')
test =  create_tuple_dataset_from_image('tmp/', ‘outimagefile’, 'tmp/', ‘outlabelfile’)

importは、以下の通りです。matplotlib.pyplotは、グラフ出力モジュールです。

import numpy as np
import chainer
from chainer import cuda, Function, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.datasets import tuple_dataset
import matplotlib.pyplot as plt

最初に関数create_tuple_dataset_from_image()を定義します。4個の引数は、イメージデータファイルのパス名、ファイル名、ラベルデータファイルのパス名、ファイル名です。verboseを1にすると、内容を確認できるprintが表示されます。verboseを1より大きい値に設定すると、イメージデータファイルの最初のイメージをプロットし、最初のラベルを表示します。

#
#create tuple dataset from image and label files
#

def create_tuple_dataset_from_image(imagepath, imagefile, labelpath, labelfile):
    verbose = 2

    imagepathfile = imagepath + imagefile
    labelpathfile = labelpath + labelfile

イメージデータファイルとラベルデータファイルをバイナリモード・READモードでオープンします(’rb’)。

    #
    #create TupleDataset
    #
    print('create tuple dataset start')
    imagef = open(imagepathfile, 'rb')
    labelf = open(labelpathfile, 'rb')

イメージデータファイルの先頭に入っている識別子、データ数、1画像当たりのデータ行数、1画像当たりのデータ列数を読み込みます。バイナリデータなのでread(4)で4バイトをREADし、それをint.from_bytes()でINTに変換します。

    bdataid = imagef.read(4)
    bdatanum = imagef.read(4)
    blinenum = imagef.read(4)
    brownum = imagef.read(4)
    if (verbose): print('bdataid = ', bdataid)
    if (verbose): print('bdatanum = ', bdatanum)
    if (verbose): print('blinenum = ', blinenum)
    if (verbose): print('brownum = ', brownum)

    dataid = int.from_bytes(bdataid, byteorder='big')
    datanum = int.from_bytes(bdatanum, byteorder='big')
    linenum = int.from_bytes(blinenum, byteorder='big')
    rownum = int.from_bytes(brownum, byteorder='big')
    if (verbose): print('dataid = ', dataid)
    if (verbose): print('datanum = ', datanum)
    if (verbose): print('linenum = ', linenum)
    if (verbose): print('rownum = ', rownum)

同様に、ラベルデータファイルの先頭に入っている識別子、データ数をを読み込みます。バイナリデータなのでread(4)で4バイトをREADし、それをint.from_bytes()でINTに変換します。

   bdataid_label = labelf.read(4)
    bdatanum_label = labelf.read(4)
    if (verbose): print('bdataid_label = ', bdataid_label)
    if (verbose): print('bdatanum_label = ', bdatanum_label)

    dataid_label = int.from_bytes(bdataid_label, byteorder='big')
    datanum_label = int.from_bytes(bdatanum_label, byteorder='big')
    linenum = int.from_bytes(blinenum, byteorder='big')
    if (verbose): print('dataid_label = ', dataid_label)
    if (verbose): print('datanum_label = ', datanum_label)

イメージデータファイルのデータ数とラベルデータファイルのデータ数が等しいことを確認します。

    if (datanum != datanum_label):
        print('numbers of image and label are different!')

イメージデータとラベルデータを1個づつ、全部でdatanum個処理します。1個のイメージデータは、linenum * rownum個のバイトデータからなるので、それを1バイトづつ読み出し、INTに変換し、255.で割ってfloatに変換します(fpixel)。それをリストfpixellistにappend()します。

1個のラベルデータは、1バイトのデータなので、それを1バイトづつ読み出し、INTに変換します(ilabel)。それをリストilabellistにappend()します。

全てのイメージデータとラベルデータを処理したら、まず、イメージデータのリストをNumPy配列にし、次元を変換します(reshape(datanum, 1, linenum, rownum))。変換したイメージデータのNumPy配列とラベルデータのリストを使ってタプルデータセットを作成します(tupledataset = tuple_dataset.TupleDataset(fpixelnparray, ilabellist))。

    fpixellist = []
    ilabellist = []
    for i in range(datanum):
        #read images
        for j in range(linenum * rownum):
            bpixel = imagef.read(1)
            fpixel = int.from_bytes(bpixel, byteorder='big') / 255.
            fpixellist.append(fpixel)

        #read labels
        blabel = labelf.read(1)
        ilabel = int.from_bytes(blabel, byteorder='big')
        ilabellist.append(ilabel)

    fpixelnparray = np.array(fpixellist, dtype = np.float32).reshape(datanum, 1, linenum, rownum)
    tupledataset = tuple_dataset.TupleDataset(fpixelnparray, ilabellist)

Verboseが1より大きい場合は、イメージデータファイルの最初のイメージをプロットし、最初のラベルを表示します。

    if (verbose > 1):
        aimage = np.array(fpixelnparray[0])
        plt.imshow(aimage.reshape([28, 28]), cmap='gray')
        plt.axis('off')
        plt.show()
        print('label = ', ilabellist[0])

イメージファイルとラベルファイルをクローズし、作成したタプルデータセットtupledatasetを戻り値にリターンします。

    imagef.close()
    labelf.close()

    print('create tuple dataset complete')

    return tupledataset

第2回で作成したcreate_image_dataset.pyのプログラムで、自分の手書き数字のイメージデータファイルとラベルデータファイルを作り、色々なところに掲載されているChainerのMNISTニューラルネットプログラムを以下の様に書き換えれば、自分の手書き数字でテストを行えます。

ただし、MNISTデータセットの手書き数字は黒地に白文字です。訓練データにはMNISTを使い、テストデータのみ自分で手書き数字を用意する場合は黒地に白文字の手書き数字の方が良さそうです。

train = create_tuple_dataset_from_image('tmp/', 'train-images-idx3-ubyte', 'tmp/', 'train-labels-idx1-ubyte')
test =  create_tuple_dataset_from_image('tmp/', ‘outimagefile’, 'tmp/', ‘outlabelfile’)

勿論、自分の手書き数字で訓練することも可能です。学習に十分な数の手書き数字を用意さえすればですが。Chainerにおけるデータの準備・設定方法を中身が理解できるよう、分かり易く書いてみたつもりです。非効率なところ等ありますがご容赦下さい。

リスト

create_tupedataset_from_image_func.py
import numpy as np
import chainer
from chainer import cuda, Function, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.datasets import tuple_dataset
import matplotlib.pyplot as plt

#
#create tuple dataset from image and label files
#

def create_tuple_dataset_from_image(imagepath, imagefile, labelpath, labelfile):
    verbose = 2

    imagepathfile = imagepath + imagefile
    labelpathfile = labelpath + labelfile

    #
    #create TupleDataset
    #
    print('create tuple dataset start')
    imagef = open(imagepathfile, 'rb')
    labelf = open(labelpathfile, 'rb')

    bdataid = imagef.read(4)
    bdatanum = imagef.read(4)
    blinenum = imagef.read(4)
    brownum = imagef.read(4)
    if (verbose): print('bdataid = ', bdataid)
    if (verbose): print('bdatanum = ', bdatanum)
    if (verbose): print('blinenum = ', blinenum)
    if (verbose): print('brownum = ', brownum)

    dataid = int.from_bytes(bdataid, byteorder='big')
    datanum = int.from_bytes(bdatanum, byteorder='big')
    linenum = int.from_bytes(blinenum, byteorder='big')
    rownum = int.from_bytes(brownum, byteorder='big')
    if (verbose): print('dataid = ', dataid)
    if (verbose): print('datanum = ', datanum)
    if (verbose): print('linenum = ', linenum)
    if (verbose): print('rownum = ', rownum)

    bdataid_label = labelf.read(4)
    bdatanum_label = labelf.read(4)
    if (verbose): print('bdataid_label = ', bdataid_label)
    if (verbose): print('bdatanum_label = ', bdatanum_label)

    dataid_label = int.from_bytes(bdataid_label, byteorder='big')
    datanum_label = int.from_bytes(bdatanum_label, byteorder='big')
    linenum = int.from_bytes(blinenum, byteorder='big')
    if (verbose): print('dataid_label = ', dataid_label)
    if (verbose): print('datanum_label = ', datanum_label)

    if (datanum != datanum_label):
        print('numbers of image and label are different!')

    fpixellist = []
    ilabellist = []
    for i in range(datanum):
        #read images
        for j in range(linenum * rownum):
            bpixel = imagef.read(1)
            fpixel = int.from_bytes(bpixel, byteorder='big') / 255.
            fpixellist.append(fpixel)

        #read labels
        blabel = labelf.read(1)
        ilabel = int.from_bytes(blabel, byteorder='big')
        ilabellist.append(ilabel)

    fpixelnparray = np.array(fpixellist, dtype = np.float32).reshape(datanum, 1, linenum, rownum)
    tupledataset = tuple_dataset.TupleDataset(fpixelnparray, ilabellist)

    if (verbose > 1):
        aimage = np.array(fpixelnparray[0])
        plt.imshow(aimage.reshape([28, 28]), cmap='gray')
        plt.axis('off')
        plt.show()
        print('label = ', ilabellist[0])

    imagef.close()
    labelf.close()

    print('create tuple dataset complete')

    return tupledataset
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0