LoginSignup
2
2

More than 3 years have passed since last update.

GANで使われていたPyTorchのモジュールについて調べてみた

Posted at

はじめに

KaggleでGANを使用するコンペがつい先日始まりました。

自分は画像系の機械学習をほとんど触れてこなかったので、
KaggleのKernelをベースに、PyTorchのモジュールについて勉強がてら少し調べてみました。

モジュール

torch.nn

ドキュメント: https://pytorch.org/docs/stable/nn.html

PyTorchのニューラルネットワークモジュールの大元のクラス。

torch.nn.Module

モデルを作るときに必要な基本のクラス、
クラスで定義し、インスタンスでモデルを定義しているようです。

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

torch.nn.functional

ドキュメント: https://pytorch.org/docs/stable/nn.html#torch-nn-functional

forward関数を定義するときに使われている。
torch.nn.Moduleで定義したModelクラス内で定義されるようです。


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))

torch.optim

ドキュメント: https://pytorch.org/docs/stable/optim.html

その名の通り最適化アルゴリズムが実装されたパッケージです。

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

torch.randn

ドキュメント: https://pytorch.org/docs/stable/torch.html#torch.randn

乱数のテンソルを返すようです。
ある例ではノイズの作成に活用されていました。

noise = torch.randn(batch_size, nz, 1, 1, device=device)

torch.full

ドキュメント: https://pytorch.org/docs/stable/torch.html#torch.full

第二引数で埋まった第一引数のサイズのテンソルを作成するようです。

labels = torch.full((batch_size, 1), real_label, device=device)

終わりに

GANのコンペが始まったということで、これを機にGANについて学んでみたいと思います。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2