Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
0
Help us understand the problem. What is going on with this article?
@hirayama968

中身を理解!Chainerにおけるデータの準備・設定方法2 〜 JPEGファイルからMNISTと同じ仕様のイメージデータファイルを作成してみる 〜

More than 1 year has passed since last update.

Chainerにおけるデータの準備・設定方法について書いています。

MNISTデータを学習するChainerのプログラムの説明はよくあるのですが、データの準備・設定方法については「train, test = datasets.get_mnist(ndim = 3)」の1行で終わっていてブラックボックスとなっていてよく分からないため、この記事を書いています。MNISTデータを学習するChainerのプログラム自身は理解していることを前提としています。以下の内容からなる全3回の2回目になります。

中身を理解!Chainerにおけるデータの準備・設定方法1
〜 MNISTデータをコピーしてみる(バイナリデータの操作) 〜
中身を理解!Chainerにおけるデータの準備・設定方法2
〜 JPEGファイルからMNISTと同じ仕様の画像データを作成してみる 〜
中身を理解!Chainerにおけるデータの準備・設定方法3
〜 chainer_datasets.get_mnist()と同等機能の関数を作成してみる 〜

第2回では、JPEGファイルからMNISTと同じ仕様のグレースケールのイメージデータファイルとラベルデータファイルを作成するプログラムについて説明します。

importは、以下の通りです。Pillow(PIL)は、画像処理ライブラリです。matplotlib.pyplotは、グラフ出力モジュールです。

import numpy as np
import chainer
from chainer import cuda, Function, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from PIL import Image
import matplotlib.pyplot as plt

verboseを1にすると、内容を確認できるprintが表示されます。verboseを1より大きい値に設定すると、イメージデータファイルに書き込んだイメージをプロットし、ラベルを表示します。JPEGファイルは、28×28、グレースケールの画像データに変換してイメージデータファイルに書き込みます。fixed_w(28)、fixed_h(28)は、JPEGファイルを変換する幅、高さのピクセル数です。MNISTと同じです。識別子dataidは128にします。

verbose = 2
fixed_w , fixed_h = 28, 28
dataid = 128

イメージファイルは手書き数字をJPEGファイルとして保存します。第3回で説明しますが、MNISTデータセットの手書き数字は黒地に白文字です。イメージファイルは複数あるので、イメージファイルの名前を記述したイメージファイルのリストを用意します。例えばイメージファイルがn00.jpgからn19.jpgの20個あるなら、イメージファイルのリストの中身は以下のようにします。

n00.jpg
n01.jpg
n02.jpg
n03.jpg
…
n17.jpg
n18.jpg
n19.jpg

イメージファイルのリストのパス名をimageflistpathに、ファイル名をimageflistfileに設定します。また、イメージファイルそのもののパス名をimagepathに設定します。ここでは、イメージファイルそのものとイメージファイルのリストをフォルダtmp/numbersの下に置いています。

imageflistpath = 'tmp/numbers/'
imageflistfile = 'imageflist.txt'
imageflistpathfile = imageflistpath + imageflistfile
imagepath = 'tmp/numbers/'

作成するイメージデータファイルのパス名をoutimagepathに、ファイル名をoutimagefileに設定します。

outimagepath = 'tmp/'
outimagefile = 'outimagefile'
outimagepathfile = outimagepath + outimagefile

イメージファイルのラベルをラベルファイルに設定します。例えばイメージファイルがn00.jpgからn19.jpgの20個で、そのラベルが0、1、2、、、9、0、、、7、8、9の20個だとしたら、ラベルファイルの中身は以下のようにします(20行)。

0
1
2
3
…
7
8
9

ラベルファイルのパス名をlabelpathに、ファイル名をlabelfileに設定します。

labelpath = 'tmp/numbers/'
labelfile = 'label.txt'
labelpathfile = labelpath + labelfile

作成するラベルデータファイルのパス名をoutlabelpathに、ファイル名をoutlabelfileに設定します。

outlabelpath = 'tmp/'
outlabelfile = 'outlabelfile'
outlabelpathfile = outlabelpath + outlabelfile

最初にイメージデータファイルを作成します。イメージファイルのリストimageflistpathfileを、READモードでオープンします。イメージファイルのリストに記述されている全てのイメージファイル名をfns[]に読み込みます。この時、readline()で各行を読み込むため、行末に改行コードが付加されているので、rstrip(‘\n’)で削除しておきます。

#
#create image dataset
#
print('create image dataset start')
imageflist = open(imageflistpathfile, 'r')
filename = imageflist.readline()
fns = []
while filename:
    fns.append(filename.rstrip('\n'))
    filename = imageflist.readline()
imageflist.close()
if (verbose): print('fns = ', fns)

イメージデータファイルの先頭に記述する識別子、データ数、1画像当たりのデータ行数、1画像当たりのデータ列数をINTからバイナリデータに変換します。その後、イメージデータファイルoutimagepathfileをバイナリモード・WRITEモードでオープンし(’wb’)、バイナリデータに変換した識別子、データ数、1画像当たりのデータ行数、1画像当たりのデータ列数を書き込みます。(識別子、データ数は、後程、ラベルデータファイルにも書き込みます。)

bdataid = dataid.to_bytes(4, byteorder='big')
bdatanum = (len(fns)).to_bytes(4, byteorder='big')
blinenum = fixed_h.to_bytes(4, byteorder='big')
brownum = fixed_w.to_bytes(4, byteorder='big')
outimagef = open(outimagepathfile, 'wb')

outimagef.write(bdataid)
outimagef.write(bdatanum)
outimagef.write(blinenum)
outimagef.write(brownum)

イメージファイルを1個づつ画像処理ライブラリPillow(PIL)で読み込みます(image = Imege.open())。読み込んだimageの幅、高をimage.sizeで求めます(w、h)。幅の方が高さより長い場合は、幅(fixed_w * w // h)、高さ(fixed_h)に、resize()します。高さの方が幅より長い場合は、幅(fixed_w)、高さ(fixed_h * h // w)に、resize()します。その後、真ん中部分だけを抽出するように、crop()で幅(fixed_w)、高さ(fixed_h)にトリミングします。(幅の方が長い場合は左右の端部分が、高さの方が長い場合は上下の端部分がカットされます。)JPEGファイルはカラーですがMNISTデータはグレースケールなので、image.convert(‘L’)でグレースケールに変換します。幅fixed_w、高さfixed_h、グレースケールに変換されたimageを、型np.float32のNumPy配列にします(pixeldata)。

変換されたイメージ(pixeldata)を、1ピクセルづつINT型(ip)に変換し、それを更にバイト型(bp)に変換し(to_bytes())、1バイトづつイメージデータファイルに書き込みます(write())。1イメージ分のINT型のピクセルデータ(ip)をリストlpに格納しているのは、verboseが1より大きい場合に、そのイメージをプロットするためです。

この後、verboseが1より大きい場合は、イメージデータファイルに書き込んだイメージをプロットします。

for fn in fns:
    image = Image.open(imagepath + fn)
    w, h = image.size
    if (verbose > 1): print('image.size = w, h = ', w, h)
    if w > h:
        shape = (fixed_w * w // h, fixed_h)
    else:
        shape = (fixed_w, fixed_h * h // w)
    if (verbose> 1): print('shape = ', shape)
    left = (shape[0] - fixed_w) // 2
    top = (shape[1] - fixed_h) // 2
    right = left + fixed_w
    bottom = top + fixed_h
    if (verbose > 1): print('left, top, right, bottom = ', left, top, right, bottom)
    image = image.resize(shape)
    if (verbose > 1): print('image.size = ', image.size, ' (after resize())')
    image = image.crop((left, top, right, bottom))
    if (verbose > 1): print('image.size = ', image.size, ' (after crop())')
    image = image.convert('L')
    pixeldata = np.array(image).astype(np.float32)

    lp = []
    for y in range(fixed_h):
        for x in range(fixed_w):
            ip = int(pixeldata[y][x])
            bp = ip.to_bytes(1, byteorder='big')
            outimagef.write(bp)
            lp.append(float(ip))
    image.close()

    if (verbose > 1):
        ap = np.array(lp)
        plt.imshow(ap.reshape([28, 28]), cmap='gray')
        plt.axis('off')
        plt.show()

イメージデータファイルが完成したのでクローズします。

outimagef.close()

ラベルデータファイルを作成する処理も基本的に同じですが、もっと単純です。まず、ラベルファイルlabelpathfileをREADモードでオープンします。次に、ラベルデータファイルoutlabelpathfileをバイナリモード・WRITEモードでオープンし(’wb’)、識別子、データ数を書き込みます。

#
#create label dataset
#
print('create label dataset start')
labelf = open(labelpathfile, 'r')
outlabelf = open(outlabelpathfile, 'wb')
outlabelf.write(bdataid)
outlabelf.write(bdatanum)

ラベルは、ラベルファイルからreadline()で各行を読み込むため、行末に改行コードが付加されているので、rstrip(‘\n’)で削除しておきます。読み込んだラベル(’0’、’1’、…、’9’)をint()でINT型に変換します。verboseが1より大きい場合は、ラベルデータファイルに書き込むラベルを表示します。更にINT型のラベルをto_bytes(1, byteorder='big’)でバイトデータに変換し、ラベルデータファイルに書き込みます。

label = labelf.readline()
while label:
    label = label.rstrip('\n')
    label = int(label)
    if (verbose > 1): print('label = ', label)
    blabel = label.to_bytes(1, byteorder='big')
    outlabelf.write(blabel)
    label = labelf.readline()

ラベルデータファイルが完成したので、ラベルファイルと共にクローズします。

labelf.close()
outlabelf.close()

print('create complete')

リスト

create_image_dataset.py
import numpy as np
import chainer
from chainer import cuda, Function, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from PIL import Image
import matplotlib.pyplot as plt

verbose = 2
fixed_w , fixed_h = 28, 28
dataid = 128

imageflistpath = 'tmp/numbers/'
imageflistfile = 'imageflist.txt'
imageflistpathfile = imageflistpath + imageflistfile
imagepath = 'tmp/numbers/'

outimagepath = 'tmp/'
outimagefile = 'outimagefile'
outimagepathfile = outimagepath + outimagefile

labelpath = 'tmp/numbers/'
labelfile = 'label.txt'
labelpathfile = labelpath + labelfile

outlabelpath = 'tmp/'
outlabelfile = 'outlabelfile'
outlabelpathfile = outlabelpath + outlabelfile

#
#create image dataset
#
print('create image dataset start')
imageflist = open(imageflistpathfile, 'r')
filename = imageflist.readline()
fns = []
while filename:
    fns.append(filename.rstrip('\n'))
    filename = imageflist.readline()
imageflist.close()
if (verbose): print('fns = ', fns)

bdataid = dataid.to_bytes(4, byteorder='big')
bdatanum = (len(fns)).to_bytes(4, byteorder='big')
blinenum = fixed_h.to_bytes(4, byteorder='big')
brownum = fixed_w.to_bytes(4, byteorder='big')
outimagef = open(outimagepathfile, 'wb')

outimagef.write(bdataid)
outimagef.write(bdatanum)
outimagef.write(blinenum)
outimagef.write(brownum)

for fn in fns:
    image = Image.open(imagepath + fn)
    w, h = image.size
    if (verbose > 1): print('image.size = w, h = ', w, h)
    if w > h:
        shape = (fixed_w * w // h, fixed_h)
    else:
        shape = (fixed_w, fixed_h * h // w)
    if (verbose> 1): print('shape = ', shape)
    left = (shape[0] - fixed_w) // 2
    top = (shape[1] - fixed_h) // 2
    right = left + fixed_w
    bottom = top + fixed_h
    if (verbose > 1): print('left, top, right, bottom = ', left, top, right, bottom)
    image = image.resize(shape)
    if (verbose > 1): print('image.size = ', image.size, ' (after resize())')
    image = image.crop((left, top, right, bottom))
    if (verbose > 1): print('image.size = ', image.size, ' (after crop())')
    image = image.convert('L')
    pixeldata = np.array(image).astype(np.float32)

    lp = []
    for y in range(fixed_h):
        for x in range(fixed_w):
            ip = int(pixeldata[y][x])
            bp = ip.to_bytes(1, byteorder='big')
            outimagef.write(bp)
            lp.append(float(ip))
    image.close()

    if (verbose > 1):
        ap = np.array(lp)
        plt.imshow(ap.reshape([28, 28]), cmap='gray')
        plt.axis('off')
        plt.show()

outimagef.close()

#
#create label dataset
#
print('create label dataset start')
labelf = open(labelpathfile, 'r')
outlabelf = open(outlabelpathfile, 'wb')
outlabelf.write(bdataid)
outlabelf.write(bdatanum)

label = labelf.readline()
while label:
    label = label.rstrip('\n')
    label = int(label)
    if (verbose > 1): print('label = ', label)
    blabel = label.to_bytes(1, byteorder='big')
    outlabelf.write(blabel)
    label = labelf.readline()

labelf.close()
outlabelf.close()

print('create complete')
tmp/numbers/imageflist.txt
n00.jpg
n01.jpg
n02.jpg
n03.jpg
n04.jpg
n05.jpg
n06.jpg
n07.jpg
n08.jpg
n09.jpg
n10.jpg
n11.jpg
n12.jpg
n13.jpg
n14.jpg
n15.jpg
n16.jpg
n17.jpg
n18.jpg
n19.jpg
tmp/numbers/label.txt
0
1
2
3
4
5
6
7
8
9
0
1
2
3
4
5
6
7
8
9
0
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
0
Help us understand the problem. What is going on with this article?