5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchのtochtextで文書分類

Last updated at Posted at 2020-03-23

はじめに

torchtextを利用して文書分類を行う実装の流れを公式のtutorialに沿って説明します。また、公式のtutorialに付随しているGoogle Colabolatryではerrorになっている部分を修正した上でコードを掲載します。最後に、torchtext.datasets.text_classificationのソースコードについて解説を行います。

開発環境

Google Colabolatry

事前知識

N-gramといった、自然言語処理の基礎用語

文書分類の流れ

torchtextを利用して文書分類を行う場合、実装は以下のような流れになります。コードについては次節で見るため、この説では概要だけ記載します。

  1. pip install
  2. moduleのimport
  3. datasetの格納、train, testへの分割
  4. modelの定義
  5. modelのinstance化、batch生成用の関数定義
  6. train, test用の関数定義
  7. train, testの実行

コード

前述の流れをtutorialに載っているコードで確認していきます。
###1. pip install
ほぼ出オチですが、公式ではこのコードが原因でerrorを発生させています。具体的には2行目が原因です。

!pip install torch<=1.2.0
!pip install torchtext
%matplotlib inline

このまま実行した場合、後述するmoduleのimportで以下のようなerrorが発生します。

from torchtext.datasets import text_classification

ImportError: cannot import name 'text_classification'

正しいコードは以下のようになります。また、torchtextのversionが変わることでruntimeの初期化が求められることがあります。その際はrestart runtimeを実行し、再度上から順にセルを実行すれば良いです (2回目のpip install後にはrestart runtimeを押す必要は無いです)。

!pip install torch<=1.2.0
!pip install torchtext==0.5
%matplotlib inline

原因はtorchtextのversionです。何も指定しないでpip installを行うと0.3.1がinstallされてしまいます。text_classificationは0.4以降で実装されているため、0.3のままでは利用できません。なお、上記では0.5に固定していますが0.4以降であれば問題ありません。

###2. moduleのimport

import torch
import torchtext
from torchtext.datasets import text_classification
NGRAMS = 2
import os

###3. datasetの格納、train, testへの分割

if not os.path.isdir('./.data'):
	os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

###4. modelの定義
embedding → linearというシンプルな流れになっています。また、init_weightでは重みの初期化を一様分布から生成した重みで行なっています。

import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

###5. modelのinstance化、batch生成用の関数定義

VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUN_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)

def generate_batch(batch):
   label = torch.tensor([entry[0] for entry in batch])
   text = [entry[1] for entry in batch]
   offsets = [0] + [len(entry) for entry in text]
   offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
   text = torch.cat(text)
   return text, offsets, label

###6. train, test用の関数定義

from torch.utils.data import DataLoader

def train_func(sub_train_):

    # Train the model
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate
    scheduler.step()
    
    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            loss = criterion(output, cls)
            loss += loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)

###7. train, testの実行
正しく学習できている場合は0.9以上のaccuracyを達成できます。

import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    mins = secs / 60
    secs = secs % 60

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')

解説

TORCHTEXT.DATASETS.TEXT_CLASSIFICATIONでは字義通りに必要なデータを提供するための処理が行われています。逆に、それ以外の操作は特に行われていません。つまり、学習に必要なデータの整形を各種データセットに対して行うことがこのmoduleのゴールになります。そのため、今回はtrain, testのデータセットを提供するまでの流れに注目して解説を行います。以下の解説で記載するソースコードはここにあります
まず以下のコードを再掲します。

###3. datasetの格納、train, testへの分割

if not os.path.isdir('./.data'):
	os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ここでは、.dataというディレクトリを作成し、そのディレクトリをrootとして利用した上でtrain, testのデータセットを生成していることがわかります。しかし、これだけでは.dataを含め色々と不明な点があります。そこで、実際にコードを読んでより具体的な処理を見ていきます。

###TORCHTEXT.DATASETS.TEXT_CLASSIFICATIONが提供するデータ
文書分類のためにいくつかのデータが提供されています。現在提供されているデータは以下の通りです。

  • AG_NEWS
  • SogouNews
  • DBpedia
  • YelpReviewPolarity
  • YelpReviewFull
  • YahooAnswers
  • AmazonReviewPolarity
  • AmazonReviewFull

それぞれのデータを直接手に入れたい場合は、URLS変数に記載されているurlからダウンロードすれば良いです。

URLS = {
    'AG_NEWS':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
    'SogouNews':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE',
    'DBpedia':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k',
    'YelpReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg',
    'YelpReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0',
    'YahooAnswers':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU',
    'AmazonReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM',
    'AmazonReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}

では実際にソースコードを通じてデータの処理を追いかけていきます。
最初に行われているのは関数の定義です。

def AG_NEWS(*args, **kwargs):
    """ Defines AG_NEWS datasets.
        The labels includes:
            - 1 : World
            - 2 : Sports
            - 3 : Business
            - 4 : Sci/Tech

    Create supervised learning dataset: AG_NEWS

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from s string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)

    """

    return _setup_datasets(*(("AG_NEWS",) + args), **kwargs)

_setup_datasets関数を用いて整形後のデータを返していることがわかります。なお、以降ではAG_NEWSのみを対象にしますが他のデータセットについても同様の処理が行われています。
次に定義した関数をDATASETS変数にdict形式で登録します。

DATASETS = {
    'AG_NEWS': AG_NEWS,
    'SogouNews': SogouNews,
    'DBpedia': DBpedia,
    'YelpReviewPolarity': YelpReviewPolarity,
    'YelpReviewFull': YelpReviewFull,
    'YahooAnswers': YahooAnswers,
    'AmazonReviewPolarity': AmazonReviewPolarity,
    'AmazonReviewFull': AmazonReviewFull
}

また、LABELS変数にdict形式で各データセットごとのラベル情報を格納しています。

LABELS = {
    'AG_NEWS': {1: 'World',
                2: 'Sports',
                3: 'Business',
                4: 'Sci/Tech'},
}

ここでは省略していますが、AG_NEWS以外のデータも同様の形式でラベルが格納されています。
上述のDATASETS変数で関数がdict形式で登録されているため、以下二つは同じものを指します。

text_classification.DATASETS['AG_NEWS']
text_classification.AG_NEWS

_setup_datasets関数を見てデータの処理を確認します。

def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))

主要な処理は以下の通りです。

  1. download_from_url関数によって指定したディレクトリに文書データを保存する。
  2. build_vocab_from_iterator関数によってデータ内で使用されている単語データを作成する
  3. _create_data_from_iteratorによって文書データからtrain, test用のcsvデータを作成する。
  4. TextClassificationDataset classに単語データ、train(test)データ、train(test)ラベルを渡してインスタンス化し、それらをまとめて返す。
    なお、download_from_url関数はGoogle Drive用に定義されたfileのダウンロードを行う関数です。
    最後にTextClassificationDataset classを見てみます。
class TextClassificationDataset(torch.utils.data.Dataset):
    """Defines an abstract text classification datasets.
       Currently, we only support the following datasets:

             - AG_NEWS
             - SogouNews
             - DBpedia
             - YelpReviewPolarity
             - YelpReviewFull
             - YahooAnswers
             - AmazonReviewPolarity
             - AmazonReviewFull

    """

[docs]    def __init__(self, vocab, data, labels):
        """Initiate text-classification dataset.

        Arguments:
            vocab: Vocabulary object used for dataset.
            data: a list of label/tokens tuple. tokens are a tensor after
                numericalizing the string tokens. label is an integer.
                [(label1, tokens1), (label2, tokens2), (label2, tokens3)]
            label: a set of the labels.
                {label1, label2}

        Examples:
            See the examples in examples/text_classification/

        """

        super(TextClassificationDataset, self).__init__()
        self._data = data
        self._labels = labels
        self._vocab = vocab


    def __getitem__(self, i):
        return self._data[i]

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        for x in self._data:
            yield x

    def get_labels(self):
        return self._labels

    def get_vocab(self):
        return self._vocab

新たにデータの処理を行うものではなく、各データを取り出すためのclassであることがわかります。_setup_datasets関数とTextClassificationDataset classを見ればわかる通り、データセットは生の文書ではなくN-gramに変換されて状態で格納されています。従って、N-gram以外のデータ形式を利用したい場合は.dataに保存されているデータか、URLSに記載されているurlからダウンロードしたデータを元に自前で処理を書く必要があります。

終わりに

printなどをするだけではわかりにくい情報もソースコードを辿ることで理解できますね。今後もソースコードを読んで情報をまとめていきたいと思います。

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?