
The benefits of PyTorch's pack_padded_sequence

Posted at 2019-12-24

What this article covers

  • To mini-batch sequences of different lengths, padding is required to bring them to a common length.
  • The goal is an implementation whose outputs do not change depending on whether padding is present.
  • This article verifies the effect of PyTorch's torch.nn.utils.rnn.pack_padded_sequence with both unidirectional and bidirectional LSTMs.

Environment

  • PyTorch 1.3.1

Conclusion first

For both unidirectional and bidirectional LSTMs, it is enough to apply
torch.nn.utils.rnn.pack_padded_sequence after padding, as the sketch below illustrates.
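
As a pattern, the whole recipe fits in a few lines. Below is a minimal sketch (not from the experiment script that follows; names and sizes are illustrative, batch_first=False throughout):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

seqs = [torch.randn(5, 8), torch.randn(3, 8)]      # variable-length inputs, shape (len, feat)
lengths = torch.tensor([s.size(0) for s in seqs])  # true lengths, recorded before padding
lstm = nn.LSTM(input_size=8, hidden_size=16)

padded = pad_sequence(seqs)                                          # (max_len, batch, feat), zero-padded
packed = pack_padded_sequence(padded, lengths, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)                                # padding never enters the LSTM
out, out_lengths = pad_packed_sequence(packed_out)                   # back to (max_len, batch, hidden)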

Source code

padding.py
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

vocab_size = 3
emb_size = 1
hidden_size = 1
num_layers = 1
num_direction = 1  # set to 2 to reproduce the bidirectional experiment
batch_size = 2

torch.manual_seed(42)

# Fix the embedding so that token i maps to the scalar i, which keeps the
# inputs easy to trace by eye (padding id 0 maps to 0.0).
emb = nn.Embedding(vocab_size, emb_size, padding_idx=0)
emb.weight = nn.Parameter(torch.tensor([[float(i)] for i in range(vocab_size)]))
rnn = nn.LSTM(input_size=emb_size,
              hidden_size=hidden_size,
              num_layers=num_layers,
              bidirectional=num_direction == 2)

print('# ----- Input -----')
X = [[1, 1], [2, 2, 2]]  # two sequences, of lengths 2 and 3
X_len = torch.tensor([len(x) for x in X])
X = [torch.tensor(x) for x in X]
print(f'ORIGINAL\n{X}')
print(f'LENGTH\n{X_len}')
padded_X = pad_sequence(X)  # (max_len, batch); the shorter sequence is padded with id 0
print(f'PADDED\n{padded_X}')
print()

print('# ----- Output (without packing) -----')
emb_X = emb(padded_X)
rnn_out, hidden = rnn(emb_X)
print(f'RNN_OUT\n{rnn_out}')
print(f'HIDDEN[0]\n{hidden[0]}')
print()

print('# ----- Output (with packing) -----')
packed_emb_X = pack_padded_sequence(emb_X, X_len, enforce_sorted=False)  # enforce_sorted=False allows batches not sorted by length
packed_rnn_out, hidden = rnn(packed_emb_X)
rnn_out, seq_len = pad_packed_sequence(packed_rnn_out)
print(f'RNN_OUT\n{rnn_out}')
print(f'HIDDEN[0]\n{hidden[0]}')

Unidirectional case

With packing, hidden holds each sequence's hidden state at its last non-padded time step. Without packing, the padding steps are also fed through the LSTM, so the shorter sequence's final state is contaminated by them (compare -0.2764 vs. -0.0035 in HIDDEN[0] below).

stdout
# ----- Input -----
ORIGINAL
[tensor([1, 1]), tensor([2, 2, 2])]
LENGTH
tensor([2, 3])
PADDED
tensor([[1, 2],
        [1, 2],
        [0, 2]])

# ----- Output (without packing) -----
RNN_OUT
tensor([[[-0.0022],
         [ 0.4329]],

        [[-0.0035],
         [ 0.5349]],

        [[-0.2764],
         [ 0.5595]]], grad_fn=<StackBackward>)
HIDDEN[0]
tensor([[[-0.2764],
         [ 0.5595]]], grad_fn=<StackBackward>)

# ----- Output (with packing) -----
RNN_OUT
tensor([[[-0.0022],
         [ 0.4329]],

        [[-0.0035],
         [ 0.5349]],

        [[ 0.0000],
         [ 0.5595]]], grad_fn=<IndexSelectBackward>)
HIDDEN[0]
tensor([[[-0.0035],
         [ 0.5595]]], grad_fn=<IndexSelectBackward>)
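
This can be checked programmatically. A minimal sketch appended to the script above (unidirectional run; rnn_out, hidden, and X_len are the variables from the packed run):

last_step = X_len - 1                           # index of the last non-padded step per sequence
batch_idx = torch.arange(rnn_out.size(1))       # 0, 1, ..., batch_size - 1
last_out = rnn_out[last_step, batch_idx]        # output at each sequence's true last step
assert torch.allclose(last_out, hidden[0][-1])  # matches the final hidden state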

Bidirectional case

The same script run with num_direction = 2. Here the backward direction is affected as well: without packing, the backward LSTM starts reading from the padded positions, so its outputs differ from the packed run (e.g. -0.2293 vs. -0.1858 at the first time step).

stdout
(input section omitted)
# ----- Output (without packing) -----
RNN_OUT
tensor([[[-0.0022, -0.2293],
         [ 0.4329, -0.1241]],

        [[-0.0035, -0.1591],
         [ 0.5349, -0.1016]],

        [[-0.2764, -0.0868],
         [ 0.5595, -0.0636]]], grad_fn=<CatBackward>)
HIDDEN[0]
tensor([[[-0.2764],
         [ 0.5595]],

        [[-0.2293],
         [-0.1241]]], grad_fn=<StackBackward>)

# ----- Output (with packing) -----
RNN_OUT
tensor([[[-0.0022, -0.1858],
         [ 0.4329, -0.1241]],

        [[-0.0035, -0.1027],
         [ 0.5349, -0.1016]],

        [[ 0.0000,  0.0000],
         [ 0.5595, -0.0636]]], grad_fn=<IndexSelectBackward>)
HIDDEN[0]
tensor([[[-0.0035],
         [ 0.5595]],

        [[-0.1858],
         [-0.1241]]], grad_fn=<IndexSelectBackward>)
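
With packing, the backward direction's final state is simply the backward output at time step 0 of each sequence, and the forward direction behaves as in the unidirectional case. A minimal sketch of the check (same variable names as above, run with num_direction = 2):

fwd_out = rnn_out[..., :hidden_size]             # forward half of the outputs
bwd_out = rnn_out[..., hidden_size:]             # backward half of the outputs
assert torch.allclose(bwd_out[0], hidden[0][1])  # backward final state = backward output at t = 0
last_step = X_len - 1
batch_idx = torch.arange(rnn_out.size(1))
assert torch.allclose(fwd_out[last_step, batch_idx], hidden[0][0])  # forward final state at true last step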