pytorch自分で学ぼうとしたけど色々躓いたのでまとめました。具体的にはpytorch tutorialの一部をGW中に翻訳・若干改良しました。この通りになめて行けば短時間で基本的なことはできるようになると思います。躓いた人、自分で書きながら勉強したい人向けに各章末にまとめのコードのリンクがあるのでよしなにご活用ください。
pytorchの特徴
- chainerに近い構文構造
- defined by run: コードを走らせた際に順伝播に合わせて逆伝播のグラフが作成される
- 比較的自由に書いても怒られない(主観)
- 動的に構造が変化するネットワークを書きやすい
- バッチ処理前提(DataLoaderの概念)
- 実行速度が早い (cf deepPoseのpytorch実装)
importモジュール群
import torch #基本モジュール
from torch.autograd import Variable #自動微分用
import torch.nn as nn #ネットワーク構築用
import torch.optim as optim #最適化関数
import torch.nn.functional as F #ネットワーク用の様々な関数
import torch.utils.data #データセット読み込み関連
import torchvision #画像関連
from torchvision import datasets, models, transforms #画像用データセット諸々
ざっくりpytorch
pytorchの基本
- pytorchはTensorという型で演算を行う
x = torch.Tensor(5, 3) #5x3のTensorの定義
y = torch.rand(5, 3) #5x3の乱数で初期化したTensorの定義
z = x + y #普通に演算も可能
- pytorchを利用するには全ての変数はTensorに変換する必要がある
- Tensor -> numpy :(Tensorの変数).numpy()
- numpy -> Tensor :torch.from_numpy(Numpyの変数)
x = np.random.rand(5, 3)
y = torch.from_numpy(x)
z = y.numpy()
- Tensor型にしたものを微分対応させるためにはさらにVariable型にする必要がある
- Variable使用時に自動微分を対応させるには独自関数を使用しないといけない
x = torch.rand(5, 3)
y = Variable(x)
z = torch.pow(y,2) + 2 #y_i**2 + 2
データの取得
pytorchで与えるデータはTensor(train), Tensor(target)のようにする必要がある。データラベルを同時に変換させる関数としてTensorDatasetがある。pytorchのDataLoaderはバッチ処理のみ対応している。
train = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
train_loader = torch.utils.data.DataLoader(train, batch_size=100, shuffle=True)
test = torch.utils.data.TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))
test_loader = torch.utils.data.DataLoader(test, batch_size=100, shuffle=True)
shuffle=Trueなどとすることであらかじめシャッフルしておくことも可能
また、torchvisionを使って画像処理関連処理(transformでTensor化、標準化、クロップなどの処理を一律に実行可能)、CIFAR-10などのデータセットの読み込みができる
#画像の変形処理
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#CIFAR-10のTensorに変形前のtrainsetのロード
rawtrainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True)
#CIFAR-10のtrain, testsetのロード
#変形はtransformを適用
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
#DataLoaderの適用->これによりバッチの割り当て・シャッフルをまとめて行うことができる
#batch_sizeでバッチサイズを指定
#num_workersでいくつのコアでデータをロードするか指定(デフォルトはメインのみ)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
モデル定義
以下の例のようにモデルをクラス定義する必要がある
- initとforwardを定義すればOK
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(784,500)
self.fc2 = nn.Linear(500, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.log_softmax(self.fc3(x))
return x
#モデル定義
model = Net()
学習
Loss関数・optimizerの設定を下記のように行う
#Loss関数の指定
criterion = nn.CrossEntropyLoss()
#Optimizerの指定
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
この元でトレーニングを行う。けっこうどれも似たような流れになると思うので参考までに載せます
#トレーニング
#エポック数の指定
for epoch in range(2): # loop over the dataset multiple times
#データ全てのトータルロス
running_loss = 0.0
for i, data in enumerate(trainloader):
#入力データ・ラベルに分割
# get the inputs
inputs, labels = data
# Variableに変形
# wrap them in Variable
inputs, labels = Variable(inputs), Variable(labels)
# optimizerの初期化
# zero the parameter gradients
optimizer.zero_grad()
#一連の流れ
# forward + backward + optimize
outputs = net(inputs)
#ここでラベルデータに対するCross-Entropyがとられる
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# ロスの表示
# print statistics
running_loss += loss.data[0]
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
ここまで、以下にコードも置いておくので参考にしてください
pytorchのチュートリアル
モデルの転移と保存
モデルの読み込み
resnetなどがデフォルトで入っているので読み込むことができる
#resnetからのモデルの読み込み
model_ft = models.resnet18(pretrained=True)
###Fine Tuning
モデルをフリーズしたり、最終層を書き換えることもできる
#プレトレーニング済のモデルのフリーズ
for param in model_conv.parameters():
param.requires_grad = False
# モデルの最終層のみパラメータを書き換える
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)
###モデルの保存と読み込み
モデルの保存は
torch.save(model.state_dict(), 'model.pth')
モデルの読み込みは
param = torch.load('model.pth')
model = Net() #読み込む前にクラス宣言が必要
model.load_state_dict(param)
転移学習関連のサンプルコードは以下
pytorchの転移学習
カスタマイズ
自作活性化関数
下記のようにクラス定義して、forwardとbackwordについて書いてあげれば良い
class MyReLU(torch.autograd.Function):
#forwardの活性化関数とbackwardの計算のみ記述すれば良い
def forward(self, input):
#値の記憶
self.save_for_backward(input)
#ReLUの定義部分
#x.clamp(min=0) <=> max(x, 0)
return input.clamp(min=0)
#backpropagationの記述
#勾配情報を返せば良い
def backward(self, grad_output):
#記憶したTensorの呼び出し
input, = self.saved_tensors
#参照渡しにならないようコピー
grad_input = grad_output.clone()
#input<0 => 0 else input
grad_input[input < 0] = 0
return grad_input
Loss関数の自作
Loss関数についてもクラス定義してinitとforwardを書けば動きます
class TripletMarginLoss(nn.Module):
def __init__(self, margin):
super(TripletMarginLoss, self).__init__()
self.margin = margin
def forward(self, anchor, positive, negative):
dist = torch.sum(
torch.pow((anchor - positive),2) - torch.pow((anchor - negative),2),
dim=1) + self.margin
dist_hinge = torch.clamp(dist, min=0.0) #max(dist, 0.0)と等価
loss = torch.mean(dist_hinge)
return loss
ダイナミックなネットワーク
defined-by-run形式なので、forward部分に条件分岐などしてレイヤーを組み替えることもできる
class DynamicNet(torch.nn.Module):
#層の定義
def __init__(self, D_in, H, D_out):
super(DynamicNet, self).__init__()
self.input_linear = torch.nn.Linear(D_in, H)
self.middle_linear = torch.nn.Linear(H, H)
self.output_linear = torch.nn.Linear(H, D_out)
#ランダムに中間層を0~3に変更する
def forward(self, x):
h_relu = self.input_linear(x).clamp(min=0)
for _ in range(random.randint(0, 3)):
h_relu = self.middle_linear(h_relu).clamp(min=0)
y_pred = self.output_linear(h_relu)
return y_pred
ちゃんとしたコードは下記
pytorchのカスタマイズ
pytorch凄い使いやすいですね。日本だとchainer派ばっかりみたいですが、とても書きやすいし動作早いし便利だと思うんですよね。みんなも試してみては。
参考にしたサイトたち
-
pytorch tutorial
http://pytorch.org/tutorials/
これが一番分かりやすい。ただ各著者が各々書いているので重複してたりしていなかったり -
pytorch超入門
http://qiita.com/miyamotok0105/items/1fd1d5c3532b174720cd
pytorchの背景からベースのとこまで色々書いてある -
実践Pytorch
http://qiita.com/perrying/items/857df46bb6cdc3047bd8
もう一つの日本語pytorch入門。割とさっくりまとまってる -
PyTorch : Tutorial 日本語訳
http://caffe.classcat.com/2017/04/14/pytorch-tutorial-tensor/
pytorch tutorialの日本語訳。機械翻訳か?これという読みづらさ -
PyTorchでDeepPoseを実装してみた
http://qiita.com/ynaka81/items/85659dff4d1c2c593f21
deepPoseのpytorch実装 -
tripletLoss関数
http://docs.chainer.org/en/stable/_modules/chainer/functions/loss/triplet.html
ChainerのtripetLoss。pytorchでもほぼ同じ記法で書ける