15
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchのConv1dを理解する

Last updated at Posted at 2021-09-20

PyTorchバージョン:1.9.0

Conv1dについての公式説明

Conv1dのコンストラクターに指定しないといけないパラメータは順番に下記三つあります。

  1. 入力チャネル数(in_channels
  2. 出力チャネル数(out_channels
  3. カーネルサイズ(kernel_size

例えば、下記のソースコードは入力チャネル数2、出力チャネル数3、カーネルサイズ5のConv1dインスタンスを作成します。

from torch import nn

conv1d = nn.Conv1d(2,3,5)

オブジェクト作成できたら、初期化された重みとバイアスが確認できます。

>>> conv1d.weight.shape
torch.Size([3, 2, 5])
>>> conv1d.weight
Parameter containing:
tensor([[[ 0.2594, -0.2927,  0.3010, -0.3144, -0.0263],
         [ 0.1818,  0.1792, -0.1513,  0.0448,  0.2669]],

        [[-0.1189,  0.1470, -0.1873, -0.1977,  0.0357],
         [ 0.1807,  0.0479,  0.2231, -0.2369, -0.1685]],

        [[ 0.0283,  0.0707,  0.0137, -0.0436, -0.2092],
         [ 0.1842,  0.2262, -0.1358, -0.1469,  0.0953]]], requires_grad=True)
>>> conv1d.bias
Parameter containing:
tensor([ 0.1507, -0.0665,  0.2158], requires_grad=True)

作成されたインスタンスconv1dを使うために、畳み込みの入力配列を用意しましょう。入力配列は三次元でなければなりません。

  1. 一次元目のサイズ:バッチサイズbatch_size
  2. 二次元目のサイズ:入力チャネル数、conv1dインスタンスを作成したときに指定した入力チャネル数と一致しないといけません
  3. 三次元目のサイズ:一件のデータの長さ、例えば、音声の場合、時間軸のサンプル(フレーム)数が一般的でしょう
>>> import torch
>>> x = torch.rand(4,2,6)
>>> x
tensor([[[0.8598, 0.9945, 0.8397, 0.0875, 0.1347, 0.2212],
         [0.9039, 0.9663, 0.2980, 0.4002, 0.8641, 0.1295]],

        [[0.6044, 0.1435, 0.9415, 0.6749, 0.1406, 0.5504],
         [0.5129, 0.1664, 0.3843, 0.5065, 0.4144, 0.2583]],

        [[0.9566, 0.8054, 0.3213, 0.8039, 0.4228, 0.9182],
         [0.8417, 0.0937, 0.0542, 0.2004, 0.9569, 0.3480]],

        [[0.7783, 0.7377, 0.9412, 0.3135, 0.7974, 0.1117],
         [0.5227, 0.1919, 0.0875, 0.8341, 0.0841, 0.1266]]])

畳み込みの計算結果を確認しましょう。

>>> y = conv1d(x)
>>> y.shape
torch.Size([4, 3, 2])
>>> y
tensor([[[ 0.8454,  0.3828],
         [-0.1566, -0.0447],
         [ 0.6581,  0.3289]],

        [[ 0.5311,  0.1668],
         [-0.4254, -0.0599],
         [ 0.2421,  0.1869]],

        [[ 0.4219,  0.4825],
         [-0.3059, -0.5375],
         [ 0.4113, -0.0433]],

        [[ 0.4764, -0.1308],
         [-0.3490, -0.0444],
         [ 0.1357,  0.1910]]], grad_fn=<SqueezeBackward1>)

次に上記の計算結果をどう導き出すかを説明します。

入力配列xのバッチサイズは4です。即ち、四件のデータがあります。畳み込み演算はx[0]x[1]x[2]x[3]にそれぞれ行われて、y[0]y[1]y[2]y[3]が得られました。

conv1d.weight[0]conv1d.bias[0]y[:,0](出力チャネル0)に対応します。

>>> conv1d.weight[0]
tensor([[ 0.2594, -0.2927,  0.3010, -0.3144, -0.0263],
        [ 0.1818,  0.1792, -0.1513,  0.0448,  0.2669]],
       grad_fn=<SelectBackward>)
>>> conv1d.weight[0,0]
tensor([ 0.2594, -0.2927,  0.3010, -0.3144, -0.0263], grad_fn=<SelectBackward>)
>>> y[0,0]  # 0件目データに対して、チャネル0の出力
tensor([0.8454, 0.3828], grad_fn=<SelectBackward>)
>>> x[0,0,0:5].dot(conv1d.weight[0,0]) + x[0,1,0:5].dot(conv1d.weight[0,1]) + conv1d.bias[0]
tensor(0.8454, grad_fn=<AddBackward0>)
>>> x[0,0,1:6].dot(conv1d.weight[0,0]) + x[0,1,1:6].dot(conv1d.weight[0,1]) + conv1d.bias[0]
tensor(0.3828, grad_fn=<AddBackward0>)
>>> y[1,0]  # 1件目データに対して、チャネル0の出力 
tensor([0.5311, 0.1668], grad_fn=<SelectBackward>)
>>> x[1,0,0:5].dot(conv1d.weight[0,0]) + x[1,1,0:5].dot(conv1d.weight[0,1]) + conv1d.bias[0]
tensor(0.5311, grad_fn=<AddBackward0>)
>>> x[1,0,1:6].dot(conv1d.weight[0,0]) + x[1,1,1:6].dot(conv1d.weight[0,1]) + conv1d.bias[0]
tensor(0.1668, grad_fn=<AddBackward0>)

conv1d.weight[1]conv1d.bias[1]y[:,1](出力チャネル1)に対応します。

>>> y[0,1]  # 0件目データに対して、チャネル1の出力
tensor([-0.1566, -0.0447], grad_fn=<SelectBackward>)
>>> x[0,0,0:5].dot(conv1d.weight[1,0]) + x[0,1,0:5].dot(conv1d.weight[1,1]) + conv1d.bias[1]
tensor(-0.1566, grad_fn=<AddBackward0>)
>>> x[0,0,1:6].dot(conv1d.weight[1,0]) + x[0,1,1:6].dot(conv1d.weight[1,1]) + conv1d.bias[1]
tensor(-0.0447, grad_fn=<AddBackward0>)
>>> y[1,1]  # 1件目データに対して、チャネル1の出力
tensor([-0.4254, -0.0599], grad_fn=<SelectBackward>)
>>> x[1,0,0:5].dot(conv1d.weight[1,0]) + x[1,1,0:5].dot(conv1d.weight[1,1]) + conv1d.bias[1]
tensor(-0.4254, grad_fn=<AddBackward0>)
>>> x[1,0,1:6].dot(conv1d.weight[1,0]) + x[1,1,1:6].dot(conv1d.weight[1,1]) + conv1d.bias[1]
tensor(-0.0599, grad_fn=<AddBackward0>)
15
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?