Help us understand the problem. What is going on with this article?

Kerasを勉強した後にPyTorchを勉強して躓いたこと

概要

DeppLearningのフレームワークで最初にKerasを勉強した後に、Define by RunのPyTorch勉強してみて躓いたポイントをまとめてみる。

この記事の対象読者

Kerasの次にPyTorchを勉強してみようと思っている人。

はじめに

今回いくつか挙げている躓いたポイントはPyTorchに限らないものがある。またKerasといえばバックエンドはTensorFlowのものを指す。バックエンドがTensorFlowでない場合は話が当てはまらないものもあるので注意。

今回挙げたポイントは以下の5つ
1. Channel First
2. GPUへの転送
3. CrossEntropyがSoftmax+CrossEntropyになっている
4. CrossEntropyがone-hot-vectorに対応していない
5. 学習と評価を区別する

以下、各ポイントの詳細について説明していく。

Channel First

PyTorchではモデルの入力と出力がChannel Firstの形式になっている。Channel Firstとは画像の次元の並びが(C, H, W)のようにChannelの次元が最初になっていること。
KerasではChannel Lastになっているため、(H, W, C)のようにChannelの次元が最後にくる。

実際にモデルに入力するときは、バッチサイズも合わせた4次元で表現する必要があるため、
PyTorch:(N, C, H, W)
Keras:(N, H, W, C)
となる。

記号の意味は
N:バッチサイズ
C:チャネル数
H:画像のHeight
W:画像のWidth

画像を読み込む際は、OpenCVかPILを使用する場合が多いが、これらのモジュールはChannel Lastで画像を扱う仕様になっている。なので、PyTorchのモデルに入力する前に以下のコードのようにChannel Firstに変換する必要がある。

img = cv2.imread(img_path)
img = img.transpose((2, 0, 1)) # H x W x C -> C x H x W

モデルの出力もChannel Firstなのでmatplotlibなどで表示したい場合はChannel Lastに変換してから表示する。

output = output.numpy() # tensor -> ndarray
output = output.transpose(1, 2, 0) # C x H x W -> H x W x C

GPUへの転送

KerasではGPUを使う場合、GPU側のメモリを意識することがなかったが、PyTorchではGPUを使用する場合、明示的に学習するパラメータや入力データをGPU側のメモリに転送しなければならない。
以下のコードではモデルと入力データをGPUに転送している。

device = torch.device("cuda:0")
# modelはnn.Moduleを継承したクラス
model = model.to(device) # GPUへ転送



for imgs, labels in train_loader:
    imgs, labels = imgs.to(device), labels.to(device) # GPUへ転送

GPU上にあるデータCPUに転送したい場合も以下のようにコードを書く必要がある。

device = torch.device("cpu")
model.to(device)

CrossEntropyがSoftmax+CrossEntropyになっている

Kerasで多クラスの識別モデルを学習するときは、モデルの最終層でsoftmaxを実行してからcategorical_crossentropyでロスを計算する流れになっている。
一方PyTorchではロス関数であるtorch.nn.CrossEntropyの中でSoftmaxの計算も一緒に行っているので、モデルの最終層でSoftmaxは不要になる。

たまにPyTorchのサンプルコードで最終層にtorch.nn.LogSoftmaxを置いて、ロス関数にtorch.nn.NLLLossを指定している場合がある。これは最終層を恒等関数にしてtorch.nn.CrossEntropyを使っているのと同じになる。
つまり、
torch.nn.CrossEntropy=torch.nn.LogSoftmaxtorch.nn.NLLLoss
という関係になっている。

torch.nn.LogSoftmaxは名前の通りSoftmaxの計算にLogをかぶせたものになっている。

LogSoftmax=log(\frac{e^{xj}}{\sum_{i=1}^{n} e^{xi}})

このLogはCrossEntropyの式にあるLogを持ってきているのだが、LogとSoftmaxを先に一緒に計算しておくことで、計算結果を安定させている。
なぜLog+Softmaxが計算的に安定するかは以下のページで解説されている。
Tricks of the Trade: LogSumExp

ちなみにtorch.nn.NLLLossはCrossEntropyのLogを抜いた他の計算を行っている。

CrossEntropyがone-hot-vectorに対応していない

Kerasではロスを計算するときに、labelはone-hot-vector形式で渡す必要があるがPyTorchでは正解の値をそのまま渡す。

例えば、3クラスの分類で正解が2番目のクラスの場合、Kerasでは[0, 1, 0]というリストをロス関数に渡すが、PyTorchでは2という値を渡す。

学習と評価を区別する

PyTorchでは、モデルを動作させるときに学習中なのか評価中なのかを明示的にコードで示す必要がある。なぜこれが必要なのかは理由が2つある。

1.学習中と評価中に挙動が変わるレイヤーがあるから
2.学習中には必要で評価中には不必要な計算があるから

1は、DropOutやBatchNormalizationなどのことで、これらのレイヤーは学習中と評価中で動作が変わる。よって、コードでこれから動作するのが学習なのか評価なのかを知らせる必要がある。
具体的には以下のようなコードになる。

# modelはnn.Moduleを継承したクラス
model.train() # 学習モードに遷移
model.eval() # 評価モードに遷移

2の不必要な計算とは計算グラフを作ることである。学習中は計算グラフを作って、誤差逆伝播法で誤差を計算グラフ上に伝播させて重みを更新する必要がある。しかし、学習以外の処理ではこの計算グラフの構築が不要になるので「計算グラフを作りません」とコードで示す必要がある。
具体的にはwith torch.no_grad()を使う。

model.eval() # 評価モードに遷移
with torch.no_grad(): # この中では計算グラフは作らない
    output = model(imgs)
Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away