24
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

SeqGANによる文章生成を理解しようとしたメモ

Posted at

はじめに

大型連休にまとまった時間ができたということで、深層学習で遊んでみたくて文章生成の実験をしていました。

これだけでも面白い結果が出てきていて、お遊びレベルでは全然OKなのですが、文法的に破綻した文章ができるなど課題も多いです。
そこで、本格的に文章生成に特化した手法を試してみたいと思い、GANをベースにした手法の一つであるSeqGANの実装を調べ、改良してみました。(とか言ってモタモタしていたら連休終わってしまった)

検証環境

  • Google Colaboratory
    • TensorFlow 2.2.0-rc4
    • ランタイム: GPU

GAN?

深層学習の分野で提案された GAN (Generative Adversarial Networks: 敵対的生成ネットワーク) という手法があります。データから学習した特徴を使って、実在しないデータを生成するようなネットワークを実現する手法です。
[1406.2661] Generative Adversarial Networks - arXiv
「Generative Adversarial Nets」論文要約と、CNTKによるGANの実装紹介 - Qiita

アイドルの顔画像で学習して、実在しないアイドルの顔画像を無限に作る、みたいなのが有名な例でしょうか1
TensorFlowによるDCGANでアイドルの顔画像生成 - すぎゃーんメモ

GANを私の浅い理解でまとめてみます。
以下では便宜上「画像」と書いていますが、実際にはテキストだったり色々なものが有り得ると思います。

  • Generator, Discriminator という2種類のモデルがある。
    • Generator は画像を生成するモデル
    • Discriminator は画像が本物か偽物(Generatorが作ったもの)かを識別するモデル
  • 2種類のモデルを異なる損失関数(目的関数)で交互に学習する。
    • Generator は、Discriminator の識別結果を使って学習する(Discriminator が本物だと思うような画像を作るように学習する
    • Discriminator は、本物(与えられた学習データの画像)と偽物(Generatorが作った画像)を見破れるように学習する
  • 結果的に、Generator単体で学習するよりも自然な(本物らしい)画像を作れるようになる

SeqGAN

画像に対してはDCGANという画像生成向けに改良された手法がありますが、文章生成のタスクでもGANを発展させたSeqGANという手法が提案されています。
[1609.05473] SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient - arXiv

論文だけだと細部の実装が分かりにくいのですが、論文の著者がTensorFlow実装を公開しているので、そちらのコードと並行して読みながら内容を理解していきます。
LantaoYu/SeqGAN: Implementation of Sequence Generative Adversarial Nets with Policy Gradient

また、他の方もSeqGANについて解説しているので、そちらも参考にします。
SeqGANを用いてテキスト(小説のあらすじ)の生成をする - Qiita

手法自体は2016年に発表されたものなので、この分野としては新しい技術ではないのですが、今更ながら自分の勉強のためにまとめてみました。

モデルの構成

まずはGeneratorとDiscriminatorがどのように構成されているか見てみましょう。

Generator

Generatorはデータを生成する方のモデルです。今回は文章が題材なので、文章(単語列)を生成するものと思ってください。

コード (generator.py) を見ると、以下の図のように実装されているようです。単語IDを埋め込み行列によりベクトル表現に変換し、1単語入力するごとに1単語の出力を得ます。単に確率の高いものを選ぶのではなく、確率に従ってランダムに単語を選んでいます(論文には MC Search(モンテカルロ探索)と書かれていますね)。最初の単語 ($x_1$) には self.start_token (実装での変数名)という特別な値を入れます(図では <BOS> と書いています)。

image.png

この構図、どこかで見たことがあると思ったら、前回の記事で試したSeq2SeqモデルのDecoder側と同じでした。

ちなみにLSTMの内部状態の初期値 $h_0, c_0$ は、SeqGANの実装によると zeros でよいようです2
内部状態の初期値と1単語目の特別な記号を入れたら、あとはひたすらランダムに単語列を作ってくれるというわけですね。

Discriminator

論文と実装 (discriminator.py) を基に図示してみました。
image.png

入力列は本物と偽物の判定を行いたい単語列です。
最初に入力列をベクトル表現に変換するところは、Generatorと同じです。
ただし、GeneratorとDiscriminatorでは別個にベクトル表現を学習しています(実装を見ると埋め込み次元数が違っています)。

次に、このベクトル表現の列を畳み込み層に通します。この時、様々なサイズのフィルタを用意することで、異なる特徴量を得ています(フィルタサイズの種類数、および各フィルタサイズとフィルタ数(=出力次元数)はハイパーパラメータです)。その結果に対して、時刻方向(単語の並びの方向)にMaxPooling演算を行います。
入力列は一般に可変長ですが3、MaxPoolingで時刻方向の最大値を取ることになるので、ここで固定長の特徴量になっています。

その後に「To enhance the performance」ということで Highway4 という名前の演算を行っています。Highway の中身は本題から外れるので割愛します。

最後に、普通の識別問題と同じく全結合層(とSoftmax)を通し、本物か偽物(作り物)かの2値識別を行います。

こうしてみると、モデルの作り自体は意外と難しくないことが分かります。可変長の入力を受け付けるようにすると、畳み込みフィルタを通すところがちょっと実装しにくいかもしれませんが。

Pre-training

モデルを作ったら、次は学習方法を調べてみます。

いきなりGANとしてGeneratorとDiscriminatorを学習する前に、まずはGeneratorに対してPre-trainingを行っていきます。ノイズしか出さないGeneratorからいきなり学習を始めるより、**先にある程度まともな文章を作るGeneratorを学習しておいたほうがGANの学習が進みやすい(性能が良くなる)**ということのようですね。これは後述のように強化学習的なアプローチを使っていることに由来すると思われます。
とはいえ、あまり完璧に学習してしまうと、今度はDiscriminatorにとっては学習の手がかりがなくなってしまうので良くないらしいです(Discriminatorにとっての識別対象は、本物データとGeneratorの作り物データなので)。

Generator

まずはGeneratorだけに注目し、学習データと同じ文を生成する確率が高くなるようにGeneratorを学習します。ある程度学習できると、ランダムに単語を選んでいってもそれっぽい文章ができる確率が上がるはずです。以下の図の「正解ラベル」を学習データセットから1文ずつ持ってきて(実際にはミニバッチで学習させますが)、クロスエントロピー損失でモデルの重みを更新していきます。
image.png

図にも描いていますが、ここでの学習データの入力側としては、**1単語目を <BOS> とし、その後は出力側より1単語分遅れたものとします。**損失関数の設定より、確率100%で正解ラベル $w_1, w_2, ...$ と同じ出力になることを期待している、と解釈できるので、1単語遅れたものを入力側に指定することは納得できます5あれ、学習もSeq2SeqのDecoderと同じなのですね(内部状態の初期値の違いはありますが)。

なお、普通に考えると学習データセットはある決まった有限個の文章の集合になるわけですが、論文および実装ではLSTMベースの学習済み言語モデル (target_lstm.py, 論文では $G_{\rm oracle}$) からサンプリングしてきた単語列をPre-trainingに利用しているようです6。もっとも、学習データセットをどう作るかというのは手法自体の本筋からはやや外れるでしょう(と思っています)。

Discriminator

続いて、Discriminator側についても、最初にある程度学習しておきます。
このときの学習データは

  • 正例:学習済み言語モデルからサンプリングした文章(10000件)
  • 負例:Pre-trainingしたGeneratorが出力した文章(10000件)

となります。
合計20000件のデータをシャッフルしてミニバッチ化し、正例については出力が1に、負例については出力が0になるように、クロスエントロピー損失関数を使ってモデルの重みを更新します(普通の識別問題と同じですね)。

メインの学習

いよいよGANとしてのメインの学習フェーズ(敵対的学習)に入ります。
前述のように、GeneratorとDiscriminatorの学習を交互に実行します。

Generator

学習方法自体は Pre-train と同様ですが、Generator自身が作った文章を入力として学習している点が違います。実装を見ると、損失関数が self.rewards で重み付けされています。これは Policy Gradient と呼ばれる強化学習の手法になっています。
この rewardsROLLOUT.get_reward() で計算されており、**「Discriminatorが本物と判断する確率(0~1の実数、確信度といった方が良いかも?)」を持っています。別の言い方をすると「『自然な』文章度合い」**を表す値ともいえます。

  • 作った文章が『自然な』文章だったら、その文章がより出やすくなるように学習する
  • 作った文章が『不自然な』文章だったら、出てくる確率(スコア)が低くても気にしない

という重み付けを、rewards によって行っていることになります。
どのようなデータが『自然な』文章か、という正解をGeneratorに直接与えるのではなく、「Generatorに文章をいくつか作らせてみて、Discriminatorによる評価結果をGeneratorに与えて学習させる」という構図になっていて、まさに強化学習ですね。そのため、敵対的学習を始める前に、Pre-trainingによってある程度の割合で『自然な』文章が作れる状態にしておいたほうが良いようです。

この重み付けは Roll-out という手法により、論文の式(7)を計算することによってGeneratorの重みを更新します7。正直この式変形についてはよく分かっていませんが…。

Discriminator

Discriminator側は、正例(本物のデータ)と負例(今の時点でのGeneratorがランダムに生成した偽物データ)を使って学習を行います。学習方法は Pre-train の時と全く同じです。
実装を見ると、正負例を各10000件、合計20000件のデータで3エポック回したら、負例の10000件を再サンプルしてまた3エポック回し、合わせて5セットの負例を処理しているようです8。元論文の Figure 3 でこのエポック数などのパターンが何通りか試されていて、その中で良かったと書かれている回数が設定されているようです(データセットによるのでしょうが)。

文章生成

ここまで来れば、最初に書いたとおり、学習済みのGeneratorに1単語目の特別な記号を入れたら、あとはひたすら出力確率に従ってランダムに単語列を作っていくことにより文章ができていくわけですね。

実装

GitHubに論文の著者によるTensorFlow実装があり、TensorFlow 1.0.1, Python 2.7 で動作するコードになっています。
LantaoYu/SeqGAN: Implementation of Sequence Generative Adversarial Nets with Policy Gradient

ただ、コードを見ると各クラスの __init__() にそのクラスに関係するすべての演算が Tensor ベースで定義されていて、sess.run() で目的の演算を実行するようになっていることから、どこからどこまでが一つの処理に関係するコードなのか分かりにくい印象です。自分の理解のため、この点に加え、以下の点の改良を試みました。

  • TF 2.x, Python 3.x 向けのコードに移行した
  • 処理の単位を整理して関数(メソッド)に分けた
    • 必要な部分に @tf.function
  • Kerasベースの書き方に移行した
    • コードが短くなった
    • CuDNNによるLSTM実装が利用可能になり、Generatorの学習・推論が数倍高速になった
  • 各段階を終えたときに、途中のモデルの保存・読み込みを行うようにした
  • コードの冗長性を減らした(同一の処理をまとめた)

個人的にはCuDNN実装への移行が結構効いていて目に見えて速くなったので良いと思います。
それでも油断しているとColabのGPUランタイムから時間切れで締め出されてしまいますが…
また、全体のコードの量もオリジナルの半分以下になり、読みやすくなっていると思います。
クラスの構成はなるべく変えないように心がけました。

オリジナルのリポジトリをフォークして変更したものを置いておきました。よろしければどうぞ。
TensorFlow 2.2でないとうまく動かない(2.1.0だと Model.fit()sample_weight を与えるところでエラーになる)ようなのでご注意ください。
build1024/SeqGAN: Implementation of Sequence Generative Adversarial Nets with Policy Gradient for TensorFlow 2.2 and Python 3.x

まとめ

SeqGANの手法について自分で手を動かしつつ調べてみました。
ただ、単語IDを操作しているだけで実際の表記との対応が分からないので、Lossの値以外にうまく動いているかを調べる術がありません。
やはり自分でデータセットを準備して試したいところです。まずは青空文庫などで試してみますかね9

  1. リンク先のブログ記事も含め、実際にはDCGANという画像生成向けに改良された手法が使われていることが多い印象です。

  2. Seq2Seqモデルでは、Encoderが出力した内部状態をDecoderの内部状態の初期値とするようになっていました。

  3. ただし今回参照しているSeqGANの実装では入力列の長さが固定値 (20) となっています。

  4. [1505.00387] Highway Networks - arXiv

  5. 損失関数(クロスエントロピー損失)の値は、正解と同じラベルが確率1 (100%) で出るとき、かつそのときに限り0になります。

  6. 論文によると、有限の文章集合を用いるよりも、実世界の文章全体の集合をよりよく再現するはずだから(意訳)とのことです。また、Generatorの学習の進み具合を学習済み言語モデルにおけるスコア(Target Loss)で測っています。

  7. 実装を見ると、Roll-out で途中までの単語列から先を予測するモデルは、元のGeneratorとは別に重みが更新されていく (rollout.update_params()) ようなのですが、なぜそうしているのかはちょっと分かりません…(Generatorの重みパラメータの変化をスムージングしているように見えます)。

  8. せっかく実世界の文章集合を言語モデルとして表現したのに、実装を見る限り、学習全体を通して正例の10000件は使い回されているっぽいのが気になります。個人的には正例も定期的にサンプリングし直したほうがよいと思うのですが、大丈夫なのですかね…?

  9. ゆくゆくは「死後さばきにあう」系の文章を作ってみたいのですが、Target LSTMを作るには圧倒的に学習データが足りないため、その対策を考える必要があります。

24
14
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
24
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?