Edited at

ライトニングpytorch入門

More than 1 year has passed since last update.

pytorch自分で学ぼうとしたけど色々躓いたのでまとめました。具体的にはpytorch tutorialの一部をGW中に翻訳・若干改良しました。この通りになめて行けば短時間で基本的なことはできるようになると思います。躓いた人、自分で書きながら勉強したい人向けに各章末にまとめのコードのリンクがあるのでよしなにご活用ください。


pytorchの特徴


  • chainerに近い構文構造

  • defined by run: コードを走らせた際に順伝播に合わせて逆伝播のグラフが作成される

  • 比較的自由に書いても怒られない(主観)

  • 動的に構造が変化するネットワークを書きやすい

  • バッチ処理前提(DataLoaderの概念)

  • 実行速度が早い (cf deepPoseのpytorch実装)


importモジュール群

import torch #基本モジュール

from torch.autograd import Variable #自動微分用
import torch.nn as nn #ネットワーク構築用
import torch.optim as optim #最適化関数
import torch.nn.functional as F #ネットワーク用の様々な関数
import torch.utils.data #データセット読み込み関連
import torchvision #画像関連
from torchvision import datasets, models, transforms #画像用データセット諸々


ざっくりpytorch


pytorchの基本


  • pytorchはTensorという型で演算を行う

x = torch.Tensor(5, 3) #5x3のTensorの定義

y = torch.rand(5, 3) #5x3の乱数で初期化したTensorの定義
z = x + y #普通に演算も可能


  • pytorchを利用するには全ての変数はTensorに変換する必要がある


    • Tensor -> numpy :(Tensorの変数).numpy()

    • numpy -> Tensor :torch.from_numpy(Numpyの変数)



x = np.random.rand(5, 3)

y = torch.from_numpy(x)
z = y.numpy()


  • Tensor型にしたものを微分対応させるためにはさらにVariable型にする必要がある


    • Variable使用時に自動微分を対応させるには独自関数を使用しないといけない



x = torch.rand(5, 3)

y = Variable(x)
z = torch.pow(y,2) + 2 #y_i**2 + 2


データの取得

pytorchで与えるデータはTensor(train), Tensor(target)のようにする必要がある。データラベルを同時に変換させる関数としてTensorDatasetがある。pytorchのDataLoaderはバッチ処理のみ対応している。

train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))

train_loader = torch.utils.data.DataLoader(train, batch_size=100, shuffle=True)
test = torch.utils.data.TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))
test_loader = torch.utils.data.DataLoader(test, batch_size=100, shuffle=True)

shuffle=Trueなどとすることであらかじめシャッフルしておくことも可能

また、torchvisionを使って画像処理関連処理(transformでTensor化、標準化、クロップなどの処理を一律に実行可能)、CIFAR-10などのデータセットの読み込みができる

#画像の変形処理

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#CIFAR-10のTensorに変形前のtrainsetのロード
rawtrainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True)

#CIFAR-10のtrain, testsetのロード
#変形はtransformを適用
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)

#DataLoaderの適用->これによりバッチの割り当て・シャッフルをまとめて行うことができる
#batch_sizeでバッチサイズを指定
#num_workersでいくつのコアでデータをロードするか指定(デフォルトはメインのみ)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)


モデル定義

以下の例のようにモデルをクラス定義する必要がある

- initとforwardを定義すればOK

class Net(nn.Module):

def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(784,500)
self.fc2 = nn.Linear(500, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.log_softmax(self.fc3(x))
return x

#モデル定義
model = Net()


学習

Loss関数・optimizerの設定を下記のように行う

#Loss関数の指定

criterion = nn.CrossEntropyLoss()

#Optimizerの指定
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

この元でトレーニングを行う。けっこうどれも似たような流れになると思うので参考までに載せます

#トレーニング

#エポック数の指定
for epoch in range(2): # loop over the dataset multiple times

#データ全てのトータルロス
running_loss = 0.0

for i, data in enumerate(trainloader):

#入力データ・ラベルに分割
# get the inputs
inputs, labels = data

# Variableに変形
# wrap them in Variable
inputs, labels = Variable(inputs), Variable(labels)

# optimizerの初期化
# zero the parameter gradients
optimizer.zero_grad()

#一連の流れ
# forward + backward + optimize
outputs = net(inputs)

#ここでラベルデータに対するCross-Entropyがとられる
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

# ロスの表示
# print statistics
running_loss += loss.data[0]
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0

print('Finished Training')

ここまで、以下にコードも置いておくので参考にしてください

pytorchのチュートリアル


モデルの転移と保存


モデルの読み込み

resnetなどがデフォルトで入っているので読み込むことができる

#resnetからのモデルの読み込み

model_ft = models.resnet18(pretrained=True)


Fine Tuning

モデルをフリーズしたり、最終層を書き換えることもできる

#プレトレーニング済のモデルのフリーズ

for param in model_conv.parameters():
param.requires_grad = False

# モデルの最終層のみパラメータを書き換える
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)


モデルの保存と読み込み

モデルの保存は

torch.save(model.state_dict(), 'model.pth')

モデルの読み込みは

param = torch.load('model.pth')

model = Net() #読み込む前にクラス宣言が必要
model.load_state_dict(param)

転移学習関連のサンプルコードは以下

pytorchの転移学習


カスタマイズ


自作活性化関数

下記のようにクラス定義して、forwardとbackwordについて書いてあげれば良い

class MyReLU(torch.autograd.Function):

#forwardの活性化関数とbackwardの計算のみ記述すれば良い
def forward(self, input):

#値の記憶
self.save_for_backward(input)

#ReLUの定義部分
#x.clamp(min=0) <=> max(x, 0)
return input.clamp(min=0)

#backpropagationの記述
#勾配情報を返せば良い
def backward(self, grad_output):

#記憶したTensorの呼び出し
input, = self.saved_tensors

#参照渡しにならないようコピー
grad_input = grad_output.clone()

#input<0 => 0 else input
grad_input[input < 0] = 0
return grad_input


Loss関数の自作

Loss関数についてもクラス定義してinitとforwardを書けば動きます

class TripletMarginLoss(nn.Module):

def __init__(self, margin):
super(TripletMarginLoss, self).__init__()
self.margin = margin

def forward(self, anchor, positive, negative):
dist = torch.sum(
torch.pow((anchor - positive),2) - torch.pow((anchor - negative),2),
dim=1) + self.margin
dist_hinge = torch.clamp(dist, min=0.0) #max(dist, 0.0)と等価
loss = torch.mean(dist_hinge)
return loss


ダイナミックなネットワーク

defined-by-run形式なので、forward部分に条件分岐などしてレイヤーを組み替えることもできる

class DynamicNet(torch.nn.Module):

#層の定義
def __init__(self, D_in, H, D_out):
super(DynamicNet, self).__init__()
self.input_linear = torch.nn.Linear(D_in, H)
self.middle_linear = torch.nn.Linear(H, H)
self.output_linear = torch.nn.Linear(H, D_out)

#ランダムに中間層を0~3に変更する
def forward(self, x):
h_relu = self.input_linear(x).clamp(min=0)
for _ in range(random.randint(0, 3)):
h_relu = self.middle_linear(h_relu).clamp(min=0)
y_pred = self.output_linear(h_relu)
return y_pred

ちゃんとしたコードは下記

pytorchのカスタマイズ




pytorch凄い使いやすいですね。日本だとchainer派ばっかりみたいですが、とても書きやすいし動作早いし便利だと思うんですよね。みんなも試してみては。


参考にしたサイトたち