はじめに
KaggleでGANを使用するコンペがつい先日始まりました。
自分は画像系の機械学習をほとんど触れてこなかったので、
KaggleのKernelをベースに、PyTorchのモジュールについて勉強がてら少し調べてみました。
モジュール
torch.nn
ドキュメント: https://pytorch.org/docs/stable/nn.html
PyTorchのニューラルネットワークモジュールの大元のクラス。
torch.nn.Module
モデルを作るときに必要な基本のクラス、
クラスで定義し、インスタンスでモデルを定義しているようです。
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
torch.nn.functional
ドキュメント: https://pytorch.org/docs/stable/nn.html#torch-nn-functional
forward関数を定義するときに使われている。
torch.nn.Moduleで定義したModelクラス内で定義されるようです。
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
torch.optim
ドキュメント: https://pytorch.org/docs/stable/optim.html
その名の通り最適化アルゴリズムが実装されたパッケージです。
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
torch.randn
ドキュメント: https://pytorch.org/docs/stable/torch.html#torch.randn
乱数のテンソルを返すようです。
ある例ではノイズの作成に活用されていました。
noise = torch.randn(batch_size, nz, 1, 1, device=device)
torch.full
ドキュメント: https://pytorch.org/docs/stable/torch.html#torch.full
第二引数で埋まった第一引数のサイズのテンソルを作成するようです。
labels = torch.full((batch_size, 1), real_label, device=device)
終わりに
GANのコンペが始まったということで、これを機にGANについて学んでみたいと思います。