LoginSignup
21
22

More than 5 years have passed since last update.

PyTorch(TorchVision)の便利なデータや変形機能をTensorFlow/Kerasでも使う

Last updated at Posted at 2019-03-21

PyTorch(TorchVision)ってデータセットがいっぱいあって便利ですよね。ものすごい羨ましい。

torchvision_01.png

「TensorFlow/KerasでもTorchVisionを使いたい!」と思ったのでやってみました。できました

Kerasの裏でPyTorchを動かすという変態的手法ですが、TorchVisionは便利すぎるのでどこかで使えるかもしれません。

これでできる

まずは結論から。TorchVisionをKerasから使うには、torchvision.transforms.Lambdaを使ってテンソル化します。

    transform = torchvision.transforms.Compose([
        # TensorFlowはChannelLastなのでTorchVisionのToTensorが使えない)
        torchvision.transforms.Lambda(lambda pic: np.asarray(pic, dtype=np.float32) / 255.0)
    ])

transformを定義したら、dataset, DataLoaderを定義します。

    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

ジェネレーターを独自に定義します。そしてPyTorchのDataLoaderをこのジェネレーターの引数に与え、forループからNumpy配列を取得します。次のように作ります。

def torchvision_generator(dataloader, n_classes):
    while True:
        for (X, y) in dataloader:
            y_onehot = np.identity(n_classes)[y.numpy()] # onehot化 
            yield X.numpy(), y_onehot

このジェネレーターをKerasのfit_generatorに食わせればOKです。

PyTorchとTensorFlowのChannelの扱いの違い

TorchVisionをTensorFlow/Kerasで使う際に一つ注意しておかなければいけないことがあります。それは、PyTorchがchannel-firstであるのに対して、TensorFlowはchannel-lastだからです。

channel-firstとは、画像のバッチが次のような4階テンソルで定義されることです。PyTorchやKerasでもMXNet-Backendがこれにあたります。

$$(Batch, Channel, Height, Width) $$

channel-firstではchannelは2番目にきます。一方で、channel-lastとは次のような4階テンソルで定義されることです。TensorFlowやTensorFlow-BackendのKerasではこれになります。

$$(Batch, Height, Width, Channel) $$

channel-lastではchannelは4番目にきます。まずはこういう違いがあるというのを認識します。

何が困るかというと、当然TorchVisionはPyTorchのためのライブラリであるので、TorchVisionのテンソル化をする操作、ToTensor()というのは、channel-firstで返ってきます。これをTensorFlow/Kerasで使うにはchannel-lastで欲しいのです。

channel-firstからchannel-lastの変換は、ToTensor()してからnp.transpose等で軸を入れ替えても良いのですが、それでは遅くなってしまったりメモリ効率が悪くなるはずなので、ToTensor()の部分をchannel-lastになるように自分で定義してしまえば良いのです。

幸いTorchVisionには独自の関数をラップするような変形が用意されています。torchvision.transforms.Lambdaという関数です(ドキュメント)。使い方はKerasのLambdaレイヤーと同じような感覚ですね。

では、TorchVisionのToTensorは何をやっているのかというと、ソースコードを読んでみると、どうもPILのインスタンス(Numpy配列のこともあるそうです)をNumpy配列に変換しているだけ、channel-firstへの変換も内部でnp.transposeを動かしているだけでした。それならPIL→Numpyの変換をnp.transposeを抜いてやればchannle-lastになるんじゃね?ということです。試してみましょう。

PyTorchのDataLoaderからCIFAR-10をプロットする

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.Lambda(lambda pic: np.asarray(pic, dtype=np.uint8))
    ])
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False)

for (X, y) in trainloader:
    X, y = X.numpy(), y.numpy()
    n = X.shape[0]
    print(X.shape)
    print(y)

    for i in range(n):
        ax = plt.subplot(2,2,i+1)
        ax.imshow(X[i])
    plt.show()

Lambdaの前に`RandomHorizontalFlip(ランダムな水平反転)を入れてみました。見た目としては、ToTensor()をLambdaに置き換えただけです。Lambdaの部分は0~1化する必要がなさそうなのでそのままuint8としました。画像の変形をtransformを使って、パイプラインで書けるのが美しいですね。

.numpy()を使ってPyTorchのTensorからNumpy配列に変形しているのが若干冗長ですが、ここのオーバーヘッドはあまり大きくなかったです(あとで見ます)。

2回動かしてみました。RandomHorizontalFlipが効いているのが理解できるしょうか。

torchvision_02.png
torchvision_03.png

X, yの中身を確認すると次のようになります。

(4, 32, 32, 3) # X.shape
[6 9 9 4] # y

無事channel-lastになっていることが確認できますね。yはそのままラベルの配列になっているので、あとでone-hot化すればよいでしょう。

TorchVisionをtransforms.Lambdaを使ってchannel-last化するだけではなく、TorchVision特有の変形操作も使えるというのが確認できました。

TorchVisionをKerasで使ったCIFAR-10分類

KerasからTorchVisionを呼んでCIFAR-10を分類してみましょう。

from keras import layers
from keras.models import Model
from keras.optimizers import SGD
import torch
import torchvision
import numpy as np

def create_block(inputs, ch, rep):
    x = inputs
    for i in range(rep):
        x = layers.Conv2D(ch, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
    return x

def create_model():
    input = layers.Input((32, 32, 3))
    x = create_block(input, 64, 3)
    x = layers.AveragePooling2D(2)(x)
    x = create_block(x, 128, 3)
    x = layers.AveragePooling2D(2)(x)
    x = create_block(x, 256, 3)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(10, activation="softmax")(x)
    return Model(input, x)

def torchvision_generator(dataloader, n_classes):
    while True:
        for (X, y) in dataloader:
            y_onehot = np.identity(n_classes)[y.numpy()] # onehot化 
            yield X.numpy(), y_onehot

def train():
    # TorchVisionのtransform 
    transform = torchvision.transforms.Compose([
        # TensorFlowはChannelLastなのでTorchVisionのToTensorが使えない)
        torchvision.transforms.Lambda(lambda pic: np.asarray(pic, dtype=np.float32) / 255.0)
    ])
    # Cifarの訓練、テスト
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
    # dataloader
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)

    # モデル
    model = create_model()
    model.compile(SGD(0.1, 0.9, 1e-3), "categorical_crossentropy", ["acc"])

    # 訓練
    model.fit_generator(torchvision_generator(trainloader, 10), steps_per_epoch=50000//128,
                        validation_data=torchvision_generator(testloader, 10),
                        validation_steps=10000//128, epochs=100)

if __name__ == "__main__":
    train()

「import keras…」と「import torch」が並んでいるのがいびつで仕方ないですが、これで動きます。

TorchVision+Keras
Using TensorFlow backend.
Files already downloaded and verified
Files already downloaded and verified
: : :
Epoch 1/100
390/390 [==============================] - 52s 134ms/step - loss: 1.4592 - acc: 0.4634 - val_loss: 1.6707 - val_acc: 0.4214
Epoch 2/100
390/390 [==============================] - 49s 127ms/step - loss: 0.9258 - acc: 0.6718 - val_loss: 0.9590 - val_acc: 0.6708
Epoch 3/100
390/390 [==============================] - 49s 126ms/step - loss: 0.6704 - acc: 0.7644 - val_loss: 0.9691 - val_acc: 0.6831
Epoch 4/100
390/390 [==============================] - 49s 126ms/step - loss: 0.5337 - acc: 0.8145 - val_loss: 0.6188 - val_acc: 0.7845
Epoch 5/100
390/390 [==============================] - 49s 126ms/step - loss: 0.4431 - acc: 0.8447 - val_loss: 0.6425 - val_acc: 0.7860

またKerasではCIFAR-10が用意されているので、全く同じネットワークでImageDataGeneratorを使って分類してみる(よくあるコード)と次のようになります。

普通のKeras
Epoch 1/100
390/390 [==============================] - 53s 135ms/step - loss: 1.4727 - acc: 0.4564 - val_loss: 1.5668 - val_acc: 0.4713
Epoch 2/100
390/390 [==============================] - 48s 124ms/step - loss: 0.9180 - acc: 0.6702 - val_loss: 1.2747 - val_acc: 0.5846
Epoch 3/100
390/390 [==============================] - 48s 123ms/step - loss: 0.6859 - acc: 0.7566 - val_loss: 0.9547 - val_acc: 0.6810
Epoch 4/100
390/390 [==============================] - 48s 123ms/step - loss: 0.5458 - acc: 0.8102 - val_loss: 0.8640 - val_acc: 0.7093
Epoch 5/100
390/390 [==============================] - 48s 123ms/step - loss: 0.4580 - acc: 0.8396 - val_loss: 0.6343 - val_acc: 0.7817

普通のKerasのほうが1ステップあたり数ms速いですね。これはTorchVisionの例がNumpy→Torchテンソル→Numpyという無駄な処理をやっているから仕方ないと思います。ただし、それでもオーバーヘッドが数ms程度ということです。もしこの無駄な処理を削る方法をご存知でしたらぜひ教えてください。

なお、精度面では特に目立った違いは確認できませんでした。

メリット

KerasからTorchVisionを使うのはどういうメリットがあるかというと、2点あります。

  1. Kerasの組み込みでは用意されていないが、TorchVisionに組み込んであるデータを読み込むのが簡単になる(例:K-MNIST、SVHN、STL10など)
  2. ImageDataGeneratorに組み込まれていなく、独自に定義する必要のあるData Augmentationの処理をtorchvision.transformsから使える。Numpyからいちいち書かなくてよくなるかもしれない

2点目はすべて確認したわけではないので、もしかしたら使えない変換もあるかもしれません。

PyTorchにはJPEGを高速に読み込めるaccimageというライブラリがついているので、もしかしたらaccimageも同様の方法で使えるかもしれません。そうしたらすごいですよね。興味のある方はぜひやってみてください。

お知らせ

4/14に開催される「技術書典6」にて、なのなのさんのブース「か72」 N4+」で売り子兼頒布をする予定です。

DeepCreamPy(ディープラーニングを使ったモザイク除去)からディープラーニングを学ぼう」というような内容を予定しております。現在鋭意執筆中です。3/21現在本ができていない(半分ぐらいは書けました)ので、サークル情報にはまだ反映されていません。

注意点

  • 私のブースではありません、なのなのさんのブースです
  • 位置的に機械学習・ディープラーニングの島ではありません。なのなのさんがVue.js+Electronの本を頒布するのでそちらの島の配置されています。
  • モザイク除去ですが全年齢版です(ここ重要)

お楽しみに!

(間に合わなかったらごめんなさい)

21
22
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
22