Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
1
Help us understand the problem. What is going on with this article?
@halhorn

tf.contrib.seq2seq.BeamSearchDecoder の length_penalty_weight

More than 3 years have passed since last update.

BeamSearch のスコアリングと length_penalty_weight

BeamSearch では、パラレルに beam_width 個の生成を各ステップで行い、そのうち最も「スコア」の高いもののみを残して次のステップを生成します。(大雑把な説明です)
ある一定の長さの文章を常に生成するのであれば、スコアとして log probability を使うことができます。(これが一番基本的な BeamSearch です。)
しかし、 Seq2Seq 等の場合生成される文章の長さには差があります。

  • A: やあ
    • logprob(やあ) = logprob(や) + logprob(あ)
  • B: こんにちは、いい天気ですね。
    • logprob(こんにちは・・・) = logprob(こ) + logprob(ん) + ... + logprob(。)

ここで logprob はマイナスの値なので、文章の長い B は小さな値に成ってしまい不利です。
その為 BeamSearch の結果短い文章ばかりが生成されることがあります。
この問題を補正するために、長さが長いほど、 log probability を大きくする正則化が必要になります。
これが length_penalty_weight です。

length_penalty_weight の計算

TensorFlow の公式の tf.contrib.seq2seq.BeamSearchDecoder では、上述の length_penalty_weight でこの正則化を行うことができます。

length_penalty_weight が 0.0 (デフォルト) だと長さでの正規化を行わず、1.0に近づくほど、長い文章の生成を優遇するようになります。(一応数式上は1.0を超えても動くはず)

コードを読むと、以下の計算式でスコアが計算されていることがわかります。

score=\frac{logprob}{length\_penalty} \\
length\_penalty=\left(\frac{5 + length}{5 + 1}\right)^\alpha

上式の $\alpha$ が length_penalty_weight です。
この式は https://arxiv.org/abs/1609.08144 の論文に書かれた方法の実装と成っています。
ちなみにこの論文ではこの数式や5というマジックナンバーの理由は特に述べられてません。経験的にこれがうまく行ったそうです。
また論文中ではこの文章の長さによる正則化の他に、 seq2seq のエンコーダーに渡す文章の各トークン(単語)をどれだけ網羅的に使えているかを attention mechanism を使って判断し、それも正則化項として加えているようです。

ちなみに論文の p.12 7 Decoder に上記の式の説明があります。

コード内での説明

lengthpenalty の tensorflow のコード内のコメントを引用しておきます。

Calculates the length penalty. See https://arxiv.org/abs/1609.08144.
Returns the length penalty tensor:

  [(5+sequence_lengths)/6]**penalty_factor

where all operations are performed element-wise.
Args:
sequence_lengths: Tensor, the sequence lengths of each hypotheses.
penalty_factor: A scalar that weights the length penalty.
Returns:
If the penalty is 0, returns the scalar 1.0. Otherwise returns
the length penalty factor, a tensor with the same shape as
sequence_lengths.

1
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
halhorn
DeepLearning で対話ロボットを作ろうとしているインコです。 https://www.wantedly.com/projects/92981
mixi
全ての人に心地よいつながりを

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
1
Help us understand the problem. What is going on with this article?