LoginSignup
0
2

More than 5 years have passed since last update.

Chainer サンプルの train_ptb.py を自前のデータセットで学習させるときのメモ

Posted at

環境

  • Windows 10
  • Python 3.5.2
  • Chainer 3.1.0

train_ptb.py?

  • LSTM を用いた RNN 言語モデルを生成するスクリプト
    • 過去の単語列から次に来るであろう単語を予測する
  • Chainer のほか TensorFlow にもサンプルがある
  • Penn Tree Bank データセットを使うから ptb らしい (TensorFlow のチュートリアル より)

学習編

学習には train_ptb.py を用いる。

自前のデータセットが使えない

事象

デフォルトでは、デモ用の以下サンプルデータセット (Penn Tree Bank) で学習・評価される。

サンプルではなく自前のデータセットで学習させたい場合は、以下記事に書かれているように Chainer の example/ptb ディレクトリに置けば良さそう。

生成された文章を適当にptb.test.txtとptb.train.txtとptb.valid.txtに分割します。
trainを一番多くしてほかは1割ずつくらいに割り当てました。
chainerのexampleのptbフォルダに3つとも移します。
http://looseleaf0727.hatenablog.jp/entry/2016/09/04/014522

しかし、自身の環境では自前のデータセットは使われず、相変わらずサンプルのデータセットが学習に使われてしまう。

vocab=10000はサンプルのボキャブラリー数
$ python ./train_ptb.py --epoch 1
#vocab = 10000
  :

解決策

ソース を読んだところ、一度サンプルをダウンロードするとキャッシュ化されるため、キャッシュが存在する限りはこれが優先して使われる模様。
そこでサンプルのデータセットが保存されるキャッシュディレクトリを探し、自前のデータセットで上書きすることで無理やり自前のそれを使わせるようにする。

キャッシュディレクトリを探す
$ python
Python 3.5.2 |Anaconda 4.1.1 (64-bit)| (default, Jul  5 2016, 11:41:13) [MSC v.1900 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import os
>>> import chainer
>>> from chainer.dataset import download
>>> os.path.join(download.get_dataset_root(), '_dl_cache')
'C:\\Users\\Tsubasa Ogawa/.chainer/dataset\\_dl_cache' # <- キャッシュディレクトリ
>>> download.get_dataset_directory('pfnet/chainer/ptb')
'C:\\Users\\Tsubasa Ogawa/.chainer/dataset\\pfnet/chainer/ptb' # <- データセットディレクトリ (あとで使う)

キャッシュディレクトリには、 URL をハッシュ化したファイル名で train, valid, test それぞれが格納されている。

キャッシュファイルの確認
$ ls -l C\:/Users/Tsubasa\ Ogawa/.chainer/dataset/_dl_cache/
total 5820
-rw-r--r-- 1 Tsubasa Ogawa 197609 5101618 1月   9 22:48 54a4921558c56e3d6f4a4d0b32b2c728 # <- train
-rw-r--r-- 1 Tsubasa Ogawa 197609  449945 1月   9 22:48 73a79a5e1ef9cb62bc3930af372954c9 # <- valid
-rw-r--r-- 1 Tsubasa Ogawa 197609  399782 1月   9 22:48 e058d9dd88b1578ff9cafeb03b711b71 # <- test

これらを自前のデータセットで上書きすればOK。

自前のファイルで上書き
$ cp -p ptb.train.txt C\:/Users/Tsubasa\ Ogawa/.chainer/dataset/_dl_cache/54a4921558c56e3d6f4a4d0b32b2c728
$ cp -p ptb.valid.txt C\:/Users/Tsubasa\ Ogawa/.chainer/dataset/_dl_cache/73a79a5e1ef9cb62bc3930af372954c9
$ cp -p ptb.test.txt C\:/Users/Tsubasa\ Ogawa/.chainer/dataset/_dl_cache/e058d9dd88b1578ff9cafeb03b711b71 
リベンジ
# サンプルをもとに作成されたファイル類 (データセットディレクトリに保存される) は綺麗にしておく
$ rm C\:/Users/Tsubasa\ Ogawa/.chainer/dataset\\pfnet/chainer/ptb/*

# 学習実行
$ python ./train_ptb.py --epoch 1
   :

UnicodeDecodeError

事象

自前のデータセットが使えるようになったが、デコードエラーが発生した。


$ python ./train_ptb.py --epoch 1
Traceback (most recent call last):
  File "./train_ptb.py", line 278, in <module>
    main()
    :
  File "C:\Program Files\Anaconda3\lib\site-packages\chainer\datasets\ptb.py", line 103, in _load_words
    for line in words_file:
UnicodeDecodeError: 'cp932' codec can't decode byte 0x8a in position 2: illegal multibyte sequence

解決策

cp932 がデフォルトエンコーディングの Windows 上で発生する問題。
ptb.py の loader() でコケているので、 loader() 内の open() 関数でエンコーディングを指定してやる。

ptb.py
    def loader(path):
        vocab = {}
        with open(path, encoding='utf-8') as f:
           :

ptb.py 中にあるほかの open 関数も同様にエンコーディングを指定しておくこと。

KeyError

事象

今度は KeyError が生じた。

$ python ./train_ptb.py --epoch 1
Traceback (most recent call last):
  File "./train_ptb.py", line 204, in main
    data['train'], data['val'], data['test'] = chainer.datasets.get_ptb_words()
  File "C:\Program Files\Anaconda3\lib\site-packages\chainer\datasets\ptb.py", line 32, in get_ptb_words
    valid = _retrieve_ptb_words('valid.npz', _valid_url)
  File "C:\Program Files\Anaconda3\lib\site-packages\chainer\datasets\ptb.py", line 69, in _retrieve_ptb_words
    loaded = download.cache_or_load_file(path, creator, numpy.load)
  File "C:\Program Files\Anaconda3\lib\site-packages\chainer\dataset\download.py", line 156, in cache_or_load_file
    content = creator(temp_path)
  File "C:\Program Files\Anaconda3\lib\site-packages\chainer\datasets\ptb.py", line 62, in creator
    x[i] = vocab[word]
KeyError: '海浜'

解決策

(筆者の推測になるが) 学習データに含まれない単語は評価データやテストデータに含んではいけない。この仮説を検証するため、一時的に「学習データ+評価データ+テストデータ」を「学習データ」として学習させ、エラーが消えることを確認する。

# データセットを一度綺麗にして
$ rm C\:/Users/Tsubasa\ Ogawa/.chainer/dataset\\pfnet/chainer/ptb/* 
# 「インチキ学習データ」をキャッシュ化
$ cat ./ptb.*.txt > ./ptb.all.txt
$ cp -p ./ptb.all.txt C\:/Users/Tsubasa\ Ogawa/.chainer/dataset/_dl_cache/54a4921558c56e3d6f4a4d0b32b2c728

確認後は、学習データを充実させるなどして学習データに未知語が含まれないようにする。
NOTE: このインチキ学習データを用いると、クローズドテストになってしまい実運用時の精度が大幅に落ちてしまう。学習データとしては使わないこと。

gentxt 編

生成したモデルを用いて文章を生成するのが gentxt.py となる。

ValueError

事象

$ python ptb/gentxt.py --model result/model_iter_4829 --primetext 今日は
  :
  File "C:\Program Files\Anaconda3\lib\site-packages\chainer\serializers\npz.py", line 122, in __call__
    numpy.copyto(value, dataset)
ValueError: could not broadcast input array from shape (19281,650) into shape (21125,650)

value と dataset の次元数が異なるためにブロードキャストができないとのこと。どうもデータセットの作成にミスっていたみたい…

解決策 (2通り)

  1. データセットを作り直す。こちらの方が確実。
  2. npz.py を修正して無理やりエラーを抑制する。 当該コードは以下部分。
npz.py
    def __call__(self, key, value):
        :
        if value is None:
            return dataset
        elif isinstance(value, numpy.ndarray):
            numpy.copyto(value, dataset)
              :

力技で両者の次元数を揃える。そのために Numpy の resize を使う。

npz.pyの問題箇所を書き換えたもの
        elif isinstance(value, numpy.ndarray):
            numpy.copyto(numpy.resize(value, dataset.shape), dataset)
              :
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2