LoginSignup
61
37

More than 5 years have passed since last update.

pytorchによる画像分類入門

Last updated at Posted at 2018-05-05

研究で画像識別系(深層学習)を扱うことになり,せっかくやるなら未来なのありそうなフレームワークを勉強しようと調べていたところpytorchに行き着きました.
海外のコミュニティが中心で,研究にも頻繁に利用されているナウなフレームワークという印象です.もうすぐver1.0がリリースされるみたいですね.

早速勉強を開始したのですが,MNISTやCifer10などの一般的なデータセットはすでにPytorch内に用意されているため,自分の研究用のデータやKaggle中のデータセットからどう学習用のデータセットを作るのか全くわかりませんでした(日本語少ない).

ですので今回はKaggleで見つけたカラー手書き文字の識別を行なった過程を,学習データの作成に重きを置いて解説したいと思います(https://www.kaggle.com/olgabelitskaya/classification-of-handwritten-letters).

自分は深層学習も機械学習もpytorchもトウシローですので良いアウトプットの仕方など教えていただければ大変に幸いです.

  • python 3.6
  • pytorch 0.3.1

データセットについて

Kaggleのカーネルより持ってきました.

手書きされたロシア語の33文字を識別するというものです.
画像データへのパス,ラベルなどが書かれたletters2.csvの構成は次のようになっています.

letter2.csv
文字, 正解ラベル, 画像へのパス, 文字画像の背景のタイプ

今回は画像の背景は気にせずに正解ラベルと画像へのパスのみを用いました.

学習用データの作成

以下を参考に行いました.
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
非常に形式ばったやり方ですのでもっと簡単な方法はありますが,私はこの形式が後々で便利だと思いました.

データの読み出し

まずデータセットの作り方です.コードは以下のようになります.

from torch.utils.data import Dataset
import os
import sys

LABEL_IDX = 1
IMG_IDX = 2


class MyDataset(Dataset):

    def __init__(self, csv_file_path, root_dir, transform=None):
        #pandasでcsvデータの読み出し
        self.image_dataframe = pd.read_csv(csv_file_path)
        self.root_dir = root_dir
        #画像データへの処理
        self.transform = transform

    def __len__(self):
        return len(self.image_dataframe)

    def __getitem__(self, idx):
        #dataframeから画像へのパスとラベルを読み出す
        label = self.image_dataframe.iat[idx, LABEL_IDX]
        img_name = os.path.join(self.root_dir, 'classification-of-handwritten-letters',
                'letters2', self.image_dataframe.iat[idx, IMG_IDX])
        #画像の読み込み
        image = io.imread(img_name)
        #画像へ処理を加える
        if self.transform:
            image = self.transform(image)

        return image, label

Datasetクラスを継承して自分のデータセットを作ります.コンストラクタではデータ(letters2.csv)へのパス,コードへの絶対パス,画像へ加えたい処理を与えます.transformについては後ほど解説します.

getitemメソッドは,データセットへのインデックスアクセスが行われた際に呼び出されます.dataframeを元に画像とラベルを読み込み,処理を加えて返すという形です.
チュートリアルでは辞書型で返していましたが,後の処理が面倒だったのでsetで返しています.

このMyDatasetクラスをインスタンス化します.

from torchvision import transforms

class MyNormalize:

    def __call__(self, image):
        shape = image.shape
        image = (image - np.mean(image))/np.std(image)*16+64
        return image


imgDataset = MyDataset(input_file_path, ROOT_DIR, transform=transforms.Compose([
    transforms.ToTensor(),
    MyNormalize()
    ]))

インスタンス化の際に画像へ加える処理Transformを与えています.ToTensorはnumpy形式のデータをpytorchでの計算に用いるtensor型へ変換する役割があります.Transformは自分で定義することもできます,MyNormalizeがそれです(本当は正規化処理はライブラリにあります).callメソッドの中にデータに加えたい処理を書きます.Transformを複数与える際は,Composeを用います.標準で用意されているTransformを用いる際はデータ形式をPIL Imageにする必要があります.

これでインデックスアクセス可能なデータセットが出来上がります.

学習データ,テストデータへの分割

scikit-learnのtrain_test_splitterを用いました.上のようにデータセットを作っておくことでそのまま分割できます.

train_data, test_data = train_test_split(imgDataset, test_size=0.2)

pytorchではバッチ処理が基本のようです.なので分割したデータをバッチ処理用のDataLoader型にします.

train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=100, shuffle=True)

データセットとバッチサイズなどを与えるだけです.シャッフルもしてくれます.
これでデエプラアニングでの学習の準備が完了しました.

CNNの構築と学習

ほとんどpytorchのexamples中のMNISTチュートリアルからのコピペです.
https://github.com/pytorch/examples/tree/master/mnist
私は雰囲気でDeepLearningをしています.

ネットワークの構築

以下のようなネットワーク(CNN)を構築しました.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(4, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(500, 100)
        self.fc2 = nn.Linear(100, 34)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 500)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

今回利用する画像は4チャネルですので最初のconvolution層への入力サイズは4となっています.また,最後のFC層からの出力は識別クラス数+1の34となっています.なんで34やねんと思っていたのですが,どうやらpytorchが0からクラスを数えているからのようです.隠れ層のサイズは適当ですが,とりあえずこれで動きました.私は雰(ry

モデルのトレーニングとテスト

こちらもほぼコピペです.わ(ry


model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.CrossEntropyLoss()

def train(epoch):
    model.train()
    for batch_idx, (image, label) in enumerate(train_loader):
        image, label = Variable(image), Variable(label)
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(image), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.data[0]))


def test():
    model.eval()

    for (image, label) in test_loader:
        image, label = Variable(image.float(), volatile=True), Variable(label)
        output = model(image)
        test_loss += criterion(output, label).data[0] # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(label.data.view_as(pred)).long().cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

#学習
for epoch in range(1, 1000 + 1):
    train(epoch)
    test()

DataLoaderが勝手にバッチサイズごとにデータを回してくれるので,それをモデルに与えて学習させます.モデルにデータを与える前にVariable型に変換してますが,これは各データの微分を計算できるようにするためです.テストの際は結果を蓄積する必要(backpropしない)がないためvolatileをTrueにします.
予測の際に必要でない操作(Dropoutなど)を無効にするため,トレーニングを行う場合とテストの場合でモデルの状態を切り替えます.
学習の際のTensorの型はデータがFloatTensor,ラベルがLongTensorとするのが一般的だそうなので,エラーを吐かれた際は明示的に指定しましょう.

学習結果

MNISTを10epoch程度の学習で95%当てられていたので同じようにやっていたのですが,何回やっても2%くらいしか精度が出ず,何日か頭を抱えていました.しかし1000 epochで行なったところ,とりあえず81%程度の精度は出すことができました.普通の機械学習手法の方が全然精度いいじゃん,と言われそうですが,これは私が深層学習も機械学習も全く理解していないためです.パラメータも全て適当です.単なる文字の識別でもモデルに突っ込むだけではダメだということがわかりました.頑張って勉強しようと思います.

まとめ

  • pytorchのモジュールを用いて学習データの作成を行なった
  • カラーMNISTの識別pytorchで構築したCNNを用いて行なった
  • 適当な設定では微妙な精度しか出ないことがわかった

ツッコミどころ満載だとは思いますが,ぜひ言葉としてアドバイスいただければ幸いです.

参考

61
37
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
61
37