はじめに
torchtextを利用して文書分類を行う実装の流れを公式のtutorialに沿って説明します。また、公式のtutorialに付随しているGoogle Colabolatryではerrorになっている部分を修正した上でコードを掲載します。最後に、torchtext.datasets.text_classificationのソースコードについて解説を行います。
開発環境
Google Colabolatry
事前知識
N-gramといった、自然言語処理の基礎用語
文書分類の流れ
torchtextを利用して文書分類を行う場合、実装は以下のような流れになります。コードについては次節で見るため、この説では概要だけ記載します。
- pip install
- moduleのimport
- datasetの格納、train, testへの分割
- modelの定義
- modelのinstance化、batch生成用の関数定義
- train, test用の関数定義
- train, testの実行
コード
前述の流れをtutorialに載っているコードで確認していきます。
###1. pip install
ほぼ出オチですが、公式ではこのコードが原因でerrorを発生させています。具体的には2行目が原因です。
!pip install torch<=1.2.0
!pip install torchtext
%matplotlib inline
このまま実行した場合、後述するmoduleのimportで以下のようなerrorが発生します。
from torchtext.datasets import text_classification
ImportError: cannot import name 'text_classification'
正しいコードは以下のようになります。また、torchtextのversionが変わることでruntimeの初期化が求められることがあります。その際はrestart runtimeを実行し、再度上から順にセルを実行すれば良いです (2回目のpip install後にはrestart runtimeを押す必要は無いです)。
!pip install torch<=1.2.0
!pip install torchtext==0.5
%matplotlib inline
原因はtorchtextのversionです。何も指定しないでpip installを行うと0.3.1がinstallされてしまいます。text_classificationは0.4以降で実装されているため、0.3のままでは利用できません。なお、上記では0.5に固定していますが0.4以降であれば問題ありません。
###2. moduleのimport
import torch
import torchtext
from torchtext.datasets import text_classification
NGRAMS = 2
import os
###3. datasetの格納、train, testへの分割
if not os.path.isdir('./.data'):
os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
###4. modelの定義
embedding → linearというシンプルな流れになっています。また、init_weightでは重みの初期化を一様分布から生成した重みで行なっています。
import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
def __init__(self, vocab_size, embed_dim, num_class):
super().__init__()
self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
self.fc = nn.Linear(embed_dim, num_class)
self.init_weights()
def init_weights(self):
initrange = 0.5
self.embedding.weight.data.uniform_(-initrange, initrange)
self.fc.weight.data.uniform_(-initrange, initrange)
self.fc.bias.data.zero_()
def forward(self, text, offsets):
embedded = self.embedding(text, offsets)
return self.fc(embedded)
###5. modelのinstance化、batch生成用の関数定義
VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUN_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)
def generate_batch(batch):
label = torch.tensor([entry[0] for entry in batch])
text = [entry[1] for entry in batch]
offsets = [0] + [len(entry) for entry in text]
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
text = torch.cat(text)
return text, offsets, label
###6. train, test用の関数定義
from torch.utils.data import DataLoader
def train_func(sub_train_):
# Train the model
train_loss = 0
train_acc = 0
data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
collate_fn=generate_batch)
for i, (text, offsets, cls) in enumerate(data):
optimizer.zero_grad()
text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
output = model(text, offsets)
loss = criterion(output, cls)
train_loss += loss.item()
loss.backward()
optimizer.step()
train_acc += (output.argmax(1) == cls).sum().item()
# Adjust the learning rate
scheduler.step()
return train_loss / len(sub_train_), train_acc / len(sub_train_)
def test(data_):
loss = 0
acc = 0
data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
for text, offsets, cls in data:
text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
with torch.no_grad():
output = model(text, offsets)
loss = criterion(output, cls)
loss += loss.item()
acc += (output.argmax(1) == cls).sum().item()
return loss / len(data_), acc / len(data_)
###7. train, testの実行
正しく学習できている場合は0.9以上のaccuracyを達成できます。
import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
random_split(train_dataset, [train_len, len(train_dataset) - train_len])
for epoch in range(N_EPOCHS):
start_time = time.time()
train_loss, train_acc = train_func(sub_train_)
valid_loss, valid_acc = test(sub_valid_)
secs = int(time.time() - start_time)
mins = secs / 60
secs = secs % 60
print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
解説
TORCHTEXT.DATASETS.TEXT_CLASSIFICATIONでは字義通りに必要なデータを提供するための処理が行われています。逆に、それ以外の操作は特に行われていません。つまり、学習に必要なデータの整形を各種データセットに対して行うことがこのmoduleのゴールになります。そのため、今回はtrain, testのデータセットを提供するまでの流れに注目して解説を行います。以下の解説で記載するソースコードはここにあります。
まず以下のコードを再掲します。
###3. datasetの格納、train, testへの分割
if not os.path.isdir('./.data'):
os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ここでは、.dataというディレクトリを作成し、そのディレクトリをrootとして利用した上でtrain, testのデータセットを生成していることがわかります。しかし、これだけでは.dataを含め色々と不明な点があります。そこで、実際にコードを読んでより具体的な処理を見ていきます。
###TORCHTEXT.DATASETS.TEXT_CLASSIFICATIONが提供するデータ
文書分類のためにいくつかのデータが提供されています。現在提供されているデータは以下の通りです。
- AG_NEWS
- SogouNews
- DBpedia
- YelpReviewPolarity
- YelpReviewFull
- YahooAnswers
- AmazonReviewPolarity
- AmazonReviewFull
それぞれのデータを直接手に入れたい場合は、URLS変数に記載されているurlからダウンロードすれば良いです。
URLS = {
'AG_NEWS':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
'SogouNews':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE',
'DBpedia':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k',
'YelpReviewPolarity':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg',
'YelpReviewFull':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0',
'YahooAnswers':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU',
'AmazonReviewPolarity':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM',
'AmazonReviewFull':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}
では実際にソースコードを通じてデータの処理を追いかけていきます。
最初に行われているのは関数の定義です。
def AG_NEWS(*args, **kwargs):
""" Defines AG_NEWS datasets.
The labels includes:
- 1 : World
- 2 : Sports
- 3 : Business
- 4 : Sci/Tech
Create supervised learning dataset: AG_NEWS
Separately returns the training and test dataset
Arguments:
root: Directory where the datasets are saved. Default: ".data"
ngrams: a contiguous sequence of n items from s string text.
Default: 1
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
include_unk: include unknown token in the data (Default: False)
Examples:
>>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)
"""
return _setup_datasets(*(("AG_NEWS",) + args), **kwargs)
_setup_datasets関数を用いて整形後のデータを返していることがわかります。なお、以降ではAG_NEWSのみを対象にしますが他のデータセットについても同様の処理が行われています。
次に定義した関数をDATASETS変数にdict形式で登録します。
DATASETS = {
'AG_NEWS': AG_NEWS,
'SogouNews': SogouNews,
'DBpedia': DBpedia,
'YelpReviewPolarity': YelpReviewPolarity,
'YelpReviewFull': YelpReviewFull,
'YahooAnswers': YahooAnswers,
'AmazonReviewPolarity': AmazonReviewPolarity,
'AmazonReviewFull': AmazonReviewFull
}
また、LABELS変数にdict形式で各データセットごとのラベル情報を格納しています。
LABELS = {
'AG_NEWS': {1: 'World',
2: 'Sports',
3: 'Business',
4: 'Sci/Tech'},
}
ここでは省略していますが、AG_NEWS以外のデータも同様の形式でラベルが格納されています。
上述のDATASETS変数で関数がdict形式で登録されているため、以下二つは同じものを指します。
text_classification.DATASETS['AG_NEWS']
text_classification.AG_NEWS
_setup_datasets関数を見てデータの処理を確認します。
def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
dataset_tar = download_from_url(URLS[dataset_name], root=root)
extracted_files = extract_archive(dataset_tar)
for fname in extracted_files:
if fname.endswith('train.csv'):
train_csv_path = fname
if fname.endswith('test.csv'):
test_csv_path = fname
if vocab is None:
logging.info('Building Vocab based on {}'.format(train_csv_path))
vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
else:
if not isinstance(vocab, Vocab):
raise TypeError("Passed vocabulary is not of type Vocab")
logging.info('Vocab has {} entries'.format(len(vocab)))
logging.info('Creating training data')
train_data, train_labels = _create_data_from_iterator(
vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
logging.info('Creating testing data')
test_data, test_labels = _create_data_from_iterator(
vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
if len(train_labels ^ test_labels) > 0:
raise ValueError("Training and test labels don't match")
return (TextClassificationDataset(vocab, train_data, train_labels),
TextClassificationDataset(vocab, test_data, test_labels))
主要な処理は以下の通りです。
- download_from_url関数によって指定したディレクトリに文書データを保存する。
- build_vocab_from_iterator関数によってデータ内で使用されている単語データを作成する
- _create_data_from_iteratorによって文書データからtrain, test用のcsvデータを作成する。
- TextClassificationDataset classに単語データ、train(test)データ、train(test)ラベルを渡してインスタンス化し、それらをまとめて返す。
なお、download_from_url関数はGoogle Drive用に定義されたfileのダウンロードを行う関数です。
最後にTextClassificationDataset classを見てみます。
class TextClassificationDataset(torch.utils.data.Dataset):
"""Defines an abstract text classification datasets.
Currently, we only support the following datasets:
- AG_NEWS
- SogouNews
- DBpedia
- YelpReviewPolarity
- YelpReviewFull
- YahooAnswers
- AmazonReviewPolarity
- AmazonReviewFull
"""
[docs] def __init__(self, vocab, data, labels):
"""Initiate text-classification dataset.
Arguments:
vocab: Vocabulary object used for dataset.
data: a list of label/tokens tuple. tokens are a tensor after
numericalizing the string tokens. label is an integer.
[(label1, tokens1), (label2, tokens2), (label2, tokens3)]
label: a set of the labels.
{label1, label2}
Examples:
See the examples in examples/text_classification/
"""
super(TextClassificationDataset, self).__init__()
self._data = data
self._labels = labels
self._vocab = vocab
def __getitem__(self, i):
return self._data[i]
def __len__(self):
return len(self._data)
def __iter__(self):
for x in self._data:
yield x
def get_labels(self):
return self._labels
def get_vocab(self):
return self._vocab
新たにデータの処理を行うものではなく、各データを取り出すためのclassであることがわかります。_setup_datasets関数とTextClassificationDataset classを見ればわかる通り、データセットは生の文書ではなくN-gramに変換されて状態で格納されています。従って、N-gram以外のデータ形式を利用したい場合は.dataに保存されているデータか、URLSに記載されているurlからダウンロードしたデータを元に自前で処理を書く必要があります。
終わりに
printなどをするだけではわかりにくい情報もソースコードを辿ることで理解できますね。今後もソースコードを読んで情報をまとめていきたいと思います。