PyTorch

PyTorchでのデータの形


データの形の不一致とは

クラスでネットワークを組んでいて、さあできあがって動かしてみようということで走らせてみたらデータの形が合ってなかった、ということは割とみんな経験があるのではないですかね?

当然、ぼくも素人なのでしょっちゅうやります(笑)

ということで、今回はメモ書き程度にそこらへんのポイントを改めて押さえておきます。


CNN

まずはCNN。

機械学習を始めるにあたって最初にやるタスクってみんなCNNによる画像分類ですよね(簡単だからかな?)

PyTorchでのCNNは使うとき

import torch

import torch.nn as nn

m = nn.Conv2d(input_channel, output_channel, kernel_size)

というような感じになりますよね。

このように最初のネットワークを設定するとき、最初の2つの引数が入力と出力の数(チャンネルの数)になります。

例えば、MNISTで手書き数字の画像分類をするとき、1枚の画像を入れて、8x8のフィルターを用いて、0から9までの10個の数字を分類するとき

m = nn.Conv2d(1, 10, 8)

というような設定をすることになります。

この層を設定して、実際に層に入力するときは以下のようになります。

m = nn.Conv2d(input_channel, output_channel, kernel_size)

input_image = torch.FloatTensor(batch_size, input_channel, image_size_height, image_size_width)

こうすると出力としては

output = m(input_image)

print(output.size())

-> torch.Size([batch_size, output_channel, output_size_height, output_size_width])

となります。

ここで、output_sizeはカーネルサイズやパディングをするかどうかで変わります。

入力のサイズは4次元であることに注意です。


Linear

画像分類でも最後の全結合層として使われる線型層ですね。

これはとても簡単で、

 m = nn.Linear(input_num, output_channel)

x = torch.randn(a, input_num)

y = torch.randn(a, b, c, input_num)

----------------------------------------------

m_x, m_y = m(x), m(y)

print(m_x.size())

-> torch.Size([a, output_num])

print(m_y.size())

-> torch.Size([a, b, c, output_num])

というように入力のサイズの最後の数字が変化しているのがわかりますよね?

a, b, cは任意です。

それ以外は変わりません。

なので、データの次元がどのようでも線型層を使うことができます。


LSTM

とても複雑なやつです。

ではデータの入力形式について以下を見ていきましょう。

m = nn.LSTM(input_num, hidden_num, num_layers)

x = torch.randn(a, b, input_num)

h = torch.randn(num_layers, b, hidden_num)

c = torch.randn(num_layers, b, hidden_num)

output, hidden, cell = m(x, (h, c))

-------------------------------------------------

print(output.size())

-> torch.Size([a, b, hidden_num])

print(hidden.size())

-> torch.Size([num_layers, b, hidden_num])

print(cell.size())

-> torch.Size([num_layers, b, hidden_num])

このようにデータのサイズは変化します。

先ほど挙げた解説記事を見ていればまあ理解はできますが、実際にデータサイズがどういう風に変化するかはもうこのように覚えておいていいかもしれません。


まとめ

以上ではCNNやLinear、LSTMなど有名どころを挙げました。

他にも色々あるのですが、とりあえず今回はこの程度にしておこうと思います。

上にあげたデータの入出力制御ができれば、サイズ関係でのエラーは回避できるようになると思うので、この記事が誰かの役に立てばなぁって感じです。

ではお疲れ様でした!