LoginSignup
1
1

More than 1 year has passed since last update.

torchtextの仕様変更対応 (4) TranslationDataset

Posted at

PyTorch 1.11(torchtext 0.12)より自然言語処理で活用していたtorchtextのFieldやTabularDatasetなど便利な機能がなくなりました。PyTorch 1.10(torchtext 0.11)まではlegacyに移動されいましたが利用することは可能でした。しかし、PyTorch 1.11(torchtext 0.12)で完全に削除されてしまいました。

今回は、翻訳やチャットボットなどSeq2Seqに利用されるTranslationDatasetについて確認していきたいと思います。

基本的な使い方

英日翻訳を例に説明します。
英語、日本語のそれぞれの対訳ファイルを用意します。ファイル名のサフィックスで言語を識別できるようにしておきます。

以下のファイルを例に基本的な利用方法を見ていきます。

test.en
this is a pen .
that is a pen , too .
where is the pen ?
test.ja
これ は ペン です 。
あれ も ペン です 。
ペン は どこ です か 。

すでに形態素解析済の対訳としております。英語は、すべて小文字としピリオドなども1単語としスペースで区切っています。

旧実装

データ読み込み

TranslationDatasetを利用すれば簡単に対訳ファイルを読み込むことができます。
TabularDataset同様にFieldを定義します。英語、日本語それぞれ生成します。Seq2Seq用に先頭に<s>、最後に</s>を追加するように設定しています。今回は、形態解析済ですが、形態素解析したい場合は、tokenizeに形態素解析を行う関数を設定してください。
TranslationDatasetには、サフィックスを除いたファイルのパス、サフィックスのペアをextsに、Fieldのペアfieldsに指定します。
これだけでデータセットとして読み込んでくれます。

from torchtext.legacy.data import Field, BucketIterator
from torchtext.legacy.datasets import TranslationDataset

# Field定義
# 英語
TEXT_EN  = Field(init_token='<s>', eos_token='</s>', batch_first=True)
# 日本語
TEXT_JA  = Field(init_token='<s>', eos_token='</s>', batch_first=True)
# データ読み込み
ds = TranslationDataset('test', exts=('.en', '.ja'), fields=(TEXT_EN, TEXT_JA))

データセットの中身を確認してみましょう。

for i in range(len(ds)):
    print(ds[i].src, ds[i].trg)
['this', 'is', 'a', 'pen', '.'] ['これ', 'は', 'ペン', 'です', '。']
['that', 'is', 'a', 'pen', ',', 'too', '.'] ['あれ', 'も', 'ペン', 'です', '。']
['where', 'is', 'the', 'pen', '?'] ['ペン', 'は', 'どこ', 'です', 'か', '。']

英日ペアが格納されていることがわかります。

単語辞書作成

英語、日本語の単語辞書を作成します。

# 単語辞書作成
# 英語
TEXT_EN.build_vocab(ds)
# 日本語
TEXT_JA.build_vocab(ds)

英語辞書の中身と出現回数です。

TEXT_EN.vocab.stoi
            {'<unk>': 0,
             '<pad>': 1,
             '<s>': 2,
             '</s>': 3,
             'is': 4,
             'pen': 5,
             '.': 6,
             'a': 7,
             ',': 8,
             '?': 9,
             'that': 10,
             'the': 11,
             'this': 12,
             'too': 13,
             'where': 14})
TEXT_EN.vocab.freqs
Counter({'this': 1,
         'is': 3,
         'a': 2,
         'pen': 3,
         '.': 2,
         'that': 1,
         ',': 1,
         'too': 1,
         'where': 1,
         'the': 1,
         '?': 1})

日本語辞書の中身と出現回数です。

TEXT_JA.vocab.stoi
            {'<unk>': 0,
             '<pad>': 1,
             '<s>': 2,
             '</s>': 3,
             '。': 4,
             'です': 5,
             'ペン': 6,
             'は': 7,
             'あれ': 8,
             'か': 9,
             'これ': 10,
             'どこ': 11,
             'も': 12})
TEXT_JA.vocab.freqs
Counter({'これ': 1,
         'は': 2,
         'ペン': 3,
         'です': 3,
         '。': 3,
         'あれ': 1,
         'も': 1,
         'どこ': 1,
         'か': 1})

ミニバッチ学習

ミニバッチ学習用にイテレータを生成します。DataLoaderに相当します。
エポックごとにシャッフルするように、shuffle=Trueを設定しています。(train=Trueと設定することも可能)

# イテレータ生成
biter = BucketIterator(dataset=ds, shuffle=True, batch_size=3)

各バッチごとのデータを確認してみます。

for i, batch in enumerate(biter):
    print(i)
    for en, ja in zip(batch.src, batch.trg):
        print(en, ja)
0
tensor([ 2, 14,  4, 11,  5,  9,  3,  1,  1]) tensor([ 2,  6,  7, 11,  5,  9,  4,  3])
tensor([ 2, 12,  4,  7,  5,  6,  3,  1,  1]) tensor([ 2, 10,  7,  6,  5,  4,  3,  1])
tensor([ 2, 10,  4,  7,  5,  8, 13,  6,  3]) tensor([ 2,  8, 12,  6,  5,  4,  3,  1])

単語数値化されています。
これだとわかりにくいので単語に変換してみます。

for i, batch in enumerate(biter):
    print(i)
    for en, ja in zip(batch.src, batch.trg):
        print([TEXT_EN.vocab.itos[e] for e in en], [TEXT_JA.vocab.itos[j] for j in ja])
0
['<s>', 'this', 'is', 'a', 'pen', '.', '</s>', '<pad>', '<pad>'] ['<s>', 'これ', 'は', 'ペン', 'です', '。', '</s>', '<pad>']
['<s>', 'that', 'is', 'a', 'pen', ',', 'too', '.', '</s>'] ['<s>', 'あれ', 'も', 'ペン', 'です', '。', '</s>', '<pad>']
['<s>', 'where', 'is', 'the', 'pen', '?', '</s>', '<pad>', '<pad>'] ['<s>', 'ペン', 'は', 'どこ', 'です', 'か', '。', '</s>']

最初と最後に<s>、</s>が追加されていることがわかります。また、シーケンス長を合わせるため不足単語は<pad>として補填されています。

もう2回実行してみます。

0
['<s>', 'that', 'is', 'a', 'pen', ',', 'too', '.', '</s>'] ['<s>', 'あれ', 'も', 'ペン', 'です', '。', '</s>', '<pad>']
['<s>', 'where', 'is', 'the', 'pen', '?', '</s>', '<pad>', '<pad>'] ['<s>', 'ペン', 'は', 'どこ', 'です', 'か', '。', '</s>']
['<s>', 'this', 'is', 'a', 'pen', '.', '</s>', '<pad>', '<pad>'] ['<s>', 'これ', 'は', 'ペン', 'です', '。', '</s>', '<pad>']
0
['<s>', 'this', 'is', 'a', 'pen', '.', '</s>', '<pad>', '<pad>'] ['<s>', 'これ', 'は', 'ペン', 'です', '。', '</s>', '<pad>']
['<s>', 'that', 'is', 'a', 'pen', ',', 'too', '.', '</s>'] ['<s>', 'あれ', 'も', 'ペン', 'です', '。', '</s>', '<pad>']
['<s>', 'where', 'is', 'the', 'pen', '?', '</s>', '<pad>', '<pad>'] ['<s>', 'ペン', 'は', 'どこ', 'です', 'か', '。', '</s>']

まとめ

これだけで、ファイルの読み込みから辞書作成、ミニバッチ学習用のデータ生成まで行ってくれます。非常に便利でした。

from torchtext.legacy.data import Field, BucketIterator
from torchtext.legacy.datasets import TranslationDataset

# Field定義
# 英語
TEXT_EN  = Field(init_token='<s>', eos_token='</s>', batch_first=True)
# 日本語
TEXT_JA  = Field(init_token='<s>', eos_token='</s>', batch_first=True)
# データ読み込み
ds = TranslationDataset('test', exts=('.en', '.ja'), fields=(TEXT_EN, TEXT_JA))
# 単語辞書作成
# 英語
TEXT_EN.build_vocab(ds)
# 日本語
TEXT_JA.build_vocab(ds)
# イテレータ生成
biter = BucketIterator(dataset=ds, shuffle=True, batch_size=3)

新実装

Field,TranslationDataset,BucketIteratorを使わずに実装していきます。

データ読み込み

TorchDataを利用し実装していきます。TorchDataは、torchtext 0.13より標準で利用するようです。

データパイプを接続しながら読み込みを行います。
ファイルをオープンするFileOpener、ファイルを1行ずつ読み込むLineReaderを利用します。
FileOpenerの第一引数は、datapipeを指定する必要があるため、単一ファイルパスを渡すことができないようです。リストとしてファイルパスを渡します。
LineReaderでファイルを1行ずつ読み込みます。既定値では、ファイルパスも返却されます。ここではファイルパスは不要なためreturn_path=Falseを設定します。
英語、日本語ファイルそれぞれdatapipeを作成します。

import torchdata.datapipes as dp

# データ読み込み(英語)
datapipe_en = dp.iter.FileOpener(['test.en'], mode='rt')
datapipe_en = dp.iter.LineReader(datapipe_en, return_path=False)
# データ読み込み(日本語)
datapipe_ja = dp.iter.FileOpener(['test.ja'], mode='rt')
datapipe_ja = dp.iter.LineReader(datapipe_ja, return_path=False)

データが読み込めたか確認します。

for text_en in datapipe_en:
    print(text_en)
this is a pen .
that is a pen , too .
where is the pen ?
for text_ja in datapipe_ja:
    print(text_ja)
これ は ペン です 。
あれ も ペン です 。
ペン は どこ です か 。

問題なく読み込めています。

単語分割

次に、単語分割を行います。ここでは、単にスペースで区切るだけのためget_tokernizerを用いました。
Mapperで変換を行います。変換を行う関数を定義します。ここではlambda式で関数を定義しています。
ここでもdatapipeをつなげるだけです。

# tokenizer設定
tokenizer = get_tokenizer(tokenizer=None)
# 単語分割
datapipe_en = dp.iter.Mapper(datapipe_en, lambda text: tokenizer(text))
datapipe_ja = dp.iter.Mapper(datapipe_ja, lambda text: tokenizer(text))

実行後にlambda式を利用しているためワーニングが表示されます。TorchDataのドキュメントにも記載されているので無視して構わないでしょう。
データを確認します。

for text_en in datapipe_en:
    print(text_en)
['this', 'is', 'a', 'pen', '.']
['that', 'is', 'a', 'pen', ',', 'too', '.']
['where', 'is', 'the', 'pen', '?']
for text_ja in datapipe_ja:
    print(text_ja)
['これ', 'は', 'ペン', 'です', '。']
['あれ', 'も', 'ペン', 'です', '。']
['ペン', 'は', 'どこ', 'です', 'か', '。']

単語に分割されていることがわかります。

単語辞書作成

辞書作成には、build_vocab_from_iteratorを利用します。datapipeを渡せば大丈夫です。英語、日本語の辞書を作成します。
特殊文字として、<unk>、<pad>とともに、Seq2Seq用に文章の先頭を表す<s>、文章の最後を表す</s>も設定しています。辞書に含まれない単語用に<unk>をdefault_indexとして設定します。

from torchtext.vocab import build_vocab_from_iterator

# 単語辞書作成(英語)
en_vocab = build_vocab_from_iterator(datapipe_en, specials=('<unk>', '<pad>', '<s>', '</s>'))
en_vocab.set_default_index(en_vocab['<unk>'])
# 単語辞書作成(日本語)
ja_vocab = build_vocab_from_iterator(datapipe_ja, specials=('<unk>', '<pad>', '<s>', '</s>'))
ja_vocab.set_default_index(ja_vocab['<unk>'])

辞書の内容を確認します。

en_vocab.get_stoi()
{'this': 12,
 'the': 11,
 '<s>': 2,
 '<pad>': 1,
 'that': 10,
 '<unk>': 0,
 'too': 13,
 '</s>': 3,
 'is': 4,
 'a': 7,
 'pen': 5,
 'where': 14,
 '?': 9,
 '.': 6,
 ',': 8}
ja_vocab.get_stoi()
{'も': 12,
 'どこ': 11,
 'これ': 10,
 '</s>': 3,
 '<s>': 2,
 'ペン': 6,
 '<unk>': 0,
 '。': 4,
 'か': 9,
 'です': 5,
 'は': 7,
 '<pad>': 1,
 'あれ': 8}

transform設定

文章の変換方法を定義します。以前は、Fieldを定義しておけば、BucketIteratorで行ってくれていましたが、自前で変換する必要があります。
辞書による変換、<s>、</s>の挿入、パディング、Tensor型への変換を行います。パディングは、ミニバッチごとに系列長を統一するため不足部分がパディングされます。

import torchtext.transforms as T

# transform生成(英語)
en_transform = T.Sequential(
    T.VocabTransform(en_vocab),
    T.AddToken(en_vocab['<s>'], begin=True),
    T.AddToken(en_vocab['</s>'], begin=False),
    T.ToTensor(padding_value=en_vocab['<pad>'])
)
# transform生成(日本語)
ja_transform = T.Sequential(
    T.VocabTransform(ja_vocab),
    T.AddToken(ja_vocab['<s>'], begin=True),
    T.AddToken(ja_vocab['</s>'], begin=False),
    T.ToTensor(padding_value=ja_vocab['<pad>'])
)

データセット作成

英語、日本語をペアにするデータセットを作成します。
Zipperを用いて、英語、日本語をペアにします。

# 英語・日本語ペアに
datapipe = dp.iter.Zipper(datapipe_en, datapipe_ja)

英語、日本語ペアになっているか確認します。

for text in datapipe:
    print(text)
(['this', 'is', 'a', 'pen', '.'], ['これ', 'は', 'ペン', 'です', '。'])
(['that', 'is', 'a', 'pen', ',', 'too', '.'], ['あれ', 'も', 'ペン', 'です', '。'])
(['where', 'is', 'the', 'pen', '?'], ['ペン', 'は', 'どこ', 'です', 'か', '。'])

ミニバッチ学習

まずは、ミニバッチごとに辞書を用いてデータを変換する関数を定義します。

# ミニバッチ時のデータ変換関数
def collate_batch(batch):
    ens = en_transform([src for (src, trg) in batch])
    jas = ja_transform([trg for (src, trg) in batch])
    return ens, jas

基本的には、DataLoaderにdatapipeを渡せばよいのですが、datapipeでは順番にデータを読み込むだけでシャッフルが行えません。そこで、to_map_style_datasetを用いてmapに変換します。これでいったんデータをすべて読み込むのでシャッフルに対応できます。

from torch.utils.data import DataLoader
from torchtext.data.functional import to_map_style_dataset

# mapに変換
ds = to_map_style_dataset(datapipe)
# DataLoader設定
data_loader = DataLoader(ds, shuffle=True, batch_size=3, collate_fn=collate_batch)

各バッチごとのデータを確認してみます。

for i, (ens, jas) in enumerate(data_loader):
    print(i)
    for en, ja in zip(ens, jas):
        print(en, ja)
0
tensor([ 2, 12,  4,  7,  5,  6,  3,  1,  1]) tensor([ 2, 10,  7,  6,  5,  4,  3,  1])
tensor([ 2, 14,  4, 11,  5,  9,  3,  1,  1]) tensor([ 2,  6,  7, 11,  5,  9,  4,  3])
tensor([ 2, 10,  4,  7,  5,  8, 13,  6,  3]) tensor([ 2,  8, 12,  6,  5,  4,  3,  1])

これだと分かりにくいので辞書を使って単語に変換します。

for i, (ens, jas) in enumerate(data_loader):
    print(i)
    for en, ja in zip(ens, jas):
        print(en_vocab.lookup_tokens(en.numpy()), ja_vocab.lookup_tokens(ja.numpy()))
0
['<s>', 'where', 'is', 'the', 'pen', '?', '</s>', '<pad>', '<pad>'] ['<s>', 'ペン', 'は', 'どこ', 'です', 'か', '。', '</s>']
['<s>', 'that', 'is', 'a', 'pen', ',', 'too', '.', '</s>'] ['<s>', 'あれ', 'も', 'ペン', 'です', '。', '</s>', '<pad>']
['<s>', 'this', 'is', 'a', 'pen', '.', '</s>', '<pad>', '<pad>'] ['<s>', 'これ', 'は', 'ペン', 'です', '。', '</s>', '<pad>']

文章の先頭に<s>、最後に</s>が追加されていることがわかります。また、文書の単語数をそろえるため不足分は<pad>が追加されています。

もう2回実行します。

0
['<s>', 'this', 'is', 'a', 'pen', '.', '</s>', '<pad>', '<pad>'] ['<s>', 'これ', 'は', 'ペン', 'です', '。', '</s>', '<pad>']
['<s>', 'where', 'is', 'the', 'pen', '?', '</s>', '<pad>', '<pad>'] ['<s>', 'ペン', 'は', 'どこ', 'です', 'か', '。', '</s>']
['<s>', 'that', 'is', 'a', 'pen', ',', 'too', '.', '</s>'] ['<s>', 'あれ', 'も', 'ペン', 'です', '。', '</s>', '<pad>']
0
['<s>', 'that', 'is', 'a', 'pen', ',', 'too', '.', '</s>'] ['<s>', 'あれ', 'も', 'ペン', 'です', '。', '</s>', '<pad>']
['<s>', 'where', 'is', 'the', 'pen', '?', '</s>', '<pad>', '<pad>'] ['<s>', 'ペン', 'は', 'どこ', 'です', 'か', '。', '</s>']
['<s>', 'this', 'is', 'a', 'pen', '.', '</s>', '<pad>', '<pad>'] ['<s>', 'これ', 'は', 'ペン', 'です', '。', '</s>', '<pad>']

シャッフルもできていますね。

まとめ

datapipeは、クラスを利用し定義してきましたが、関数を利用する方が推奨されています。それぞれ関数がありますのでドキュメントを参照してください。

# データ読み込み、単語分割(英語)
datapipe_en = dp.iter.FileOpener(['test.en'], mode='rt')
datapipe_en = dp.iter.LineReader(datapipe_en, return_path=False)
datapipe_en = dp.iter.Mapper(datapipe_en, lambda text: tokenizer(text))
# データ読み込み、単語分割(日本語)
datapipe_ja = dp.iter.FileOpener(['test.ja'], mode='rt')
datapipe_ja = dp.iter.LineReader(datapipe_ja, return_path=False)
datapipe_ja = dp.iter.Mapper(datapipe_ja, lambda text: tokenizer(text))
# 英語・日本語ペアに
datapipe = dp.iter.Zipper(datapipe_en, datapipe_ja)

関数で書くと以下のようになります。

# データ読み込み、単語分割(英語)
datapipe_en = dp.iter.FileOpener(['test.en'], mode='rt'). \
                   readlines(return_path=False). \
                   map(lambda text: tokenizer(text))
# データ読み込み、単語分割(日本語)
datapipe_ja = dp.iter.FileOpener(['test.ja'], mode='rt'). \
                   readlines(return_path=False). \
                   map(lambda text: tokenizer(text))
# 英語・日本語ペアに
datapipe = datapipe_en.zip(datapipe_ja)

すべてのプログラムです。

import torchdata.datapipes as dp
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T
from torch.utils.data import DataLoader
from torchtext.data.functional import to_map_style_dataset

# tokenizer設定
tokenizer = get_tokenizer(tokenizer=None)
# データ読み込み、単語分割(英語)
datapipe_en = dp.iter.FileOpener(['test.en'], mode='rt'). \
                   readlines(return_path=False). \
                   map(lambda text: tokenizer(text))
# データ読み込み、単語分割(日本語)
datapipe_ja = dp.iter.FileOpener(['test.ja'], mode='rt'). \
                   readlines(return_path=False). \
                   map(lambda text: tokenizer(text))
# 英語・日本語ペアに
datapipe = datapipe_en.zip(datapipe_ja)
# 単語辞書作成(英語)
en_vocab = build_vocab_from_iterator(datapipe_en, specials=('<unk>', '<pad>', '<s>', '</s>'))
en_vocab.set_default_index(en_vocab['<unk>'])
# 単語辞書作成(日本語)
ja_vocab = build_vocab_from_iterator(datapipe_ja, specials=('<unk>', '<pad>', '<s>', '</s>'))
ja_vocab.set_default_index(ja_vocab['<unk>'])
# transform生成
en_transform = T.Sequential(
    T.VocabTransform(en_vocab),
    T.AddToken(en_vocab['<s>'], begin=True),
    T.AddToken(en_vocab['</s>'], begin=False),
    T.ToTensor(padding_value=en_vocab['<pad>'])
)
ja_transform = T.Sequential(
    T.VocabTransform(ja_vocab),
    T.AddToken(ja_vocab['<s>'], begin=True),
    T.AddToken(ja_vocab['</s>'], begin=False),
    T.ToTensor(padding_value=ja_vocab['<pad>'])
)
# ミニバッチ時のデータ変換関数
def collate_batch(batch):
    ens = en_transform([src for (src, trg) in batch])
    jas = ja_transform([trg for (src, trg) in batch])
    return ens, jas
# mapに変換
ds = to_map_style_dataset(datapipe)
# DataLoader設定
data_loader = DataLoader(ds, shuffle=True, batch_size=3, collate_fn=collate_batch)

簡潔に書いたつもりですが、かなり長くなりました。旧実装と比較すると一目瞭然です。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1