LoginSignup
5
5

More than 3 years have passed since last update.

Beam Searchの実装

Posted at

 本稿では、前回投稿で作成したTransformerの応答文生成処理に、Beam Searchを追加して、精度がどれくらい改善するかを定性的に確認します。

1. はじめに

 Transformer実装にあたって参考にした原論文「Attention Is All You Need」には、Beam Searchに関する記述があったのですが、Beam Searchを実装していないLSTMとの比較条件をそろえるため、前回投稿「Kerasで実装するTransformer」では実装を見送りました。

 今回は改めてBeam Searchを実装し、応答文生成の精度がどれくらい改善するか、見てみることにします。

2. 本稿のゴール

 以下の通りです。

  • Beam Searchの実装
  • Beam Search効果の定性評価

 なお、実行環境や訓練データなどはすべて、「Kerasで実装するTransformer」と同じです。

3. Beam Searchの説明

 自然言語処理など、逐次的にアウトプットを予測していく処理では、普通は時刻ごとに出力を予測します。

 一方、Beam Searchでは、各時刻において、その時刻までの系列全体の予測確率の大きいほうからbeam_size個分抽出する、ということを繰り返し、最後まで予測し終わった後で、一番確率が大きい系列を正解とします。

 具体的には、予測確率の対数を取った対数尤度を算出して、出力系列を決定します。方式は、「Attention Is All You Need」が引用している、ちょっと長いタイトルの論文「Google’s Neural Machine Translation System: Bridging the Gap
between Human and Machine Translation
」を参考にしました。

 対数尤度は、時系列方向に帰納的に決定していきます。出力次元数を$N$、時刻${i}$における出力ベクトルを${o_i=(o_{ij})_{j \in [0,N-1]}}$とします。

 各${o_{ij}}$はSoftmaxの出力結果なので、1以下の正の実数です。

 このとき、時刻$i$ および出力次元${k \in [0,}$ beam_size ${-1]}$における対数尤度$s_{ik}$を、以下のように定義します。

$i=0$のとき

 集合 { ${log(o_{0j})\ |\ 0\leqq j \leqq N-1}$} の要素のうち、大きいほうからbeam_size個の要素を、$s_{0k}$ , ${k \in [0,}$ beam_size ${-1]}$と定義します。

$i\geqq1$のとき

 集合 { ${log(o_{ij})/((6+i)/6)^{\alpha}+s_{i-1\ l}\ |\ 0\leqq j \leqq N-1\ ,\ 0\leqq l \leqq}$ beam_size ${-1}$} の要素のうち、大きいほうからbeam_size個の要素を、$s_{ik}$ , ${k \in [0,}$ beam_size ${-1]}$と定義します。

 ここに$((6+i)/6)^{\alpha}$は、Beam Searchにおいて長い系列が不利になるのを緩和するための補正項で、$\alpha$は1未満の正の実数です。$\alpha$の実際の値は、「Attention Is All You Need」に倣って、0.6にしてあります。

 実は、引用元論文「Google’s Neural Machine Translation System: Bridging the Gap
between Human and Machine Translation
」では、対数尤度にattentionの係数を足していましたが、実装が大変そうなのと、「Attention Is All You Need」では言及がなかったので、その実装は今回は割愛しました。

4. ソースコード

 前回投稿の「4−2. 応答文生成処理(2000_response.ipynb)」の「cell[5] 応答文組み立て」を、以下のソースコードに置き換えます。

\

クリックして表示
2000_response.ipynb cell[5] 応答文組み立て

#*******************************************************************************
#                                                                              *
#   応答文組み立て                                                             *
#                                                                              *
#*******************************************************************************

def generate_response(e_i, data) :
    maxlen_e      = data['maxlen_e']
    maxlen_d      = data['maxlen_d']
    n_hidden      = data['n_hidden']
    output_dim    = data['output_dim']
    indices_word  = data['indices_word']
    word_indices  = data['word_indices']
    words         = data['words']
    encoder_model = data['encoder_model']
    decoder_model = data['decoder_model']

    beam_size = 4
    alpha = 0.6

    # Encode the input as state vectors.
    # エンコーダインプットをbeam_size分、行コピーする
    e_input = np.array([e_i] * beam_size).reshape((beam_size, maxlen_e))
    encoder_results = encoder_model.predict(e_input)
    encoder_outputs = encoder_results[:-1]
    encoder_i_mask = encoder_results[-1]
    decoded_sentence = ''
    target_seq = np.zeros((1, maxlen_d+1) ,dtype='int32')
    target_seq[0,  0] = word_indices['SSSS']
    target_seq = np.array([target_seq] * beam_size)
    target_seq = target_seq.reshape((beam_size, maxlen_d+1))
    target_data = [[target_seq[i], 0.0, 0] for i in range(beam_size)]

    # 応答文字予測
    for i in range(0,maxlen_d) :
        target_seq = np.array([target_data[j][0] for j in range(beam_size)])
        target_seq = target_seq.reshape((beam_size, maxlen_d+1))
        target_seq = target_seq[:, :-1]
        decoder_output = decoder_model.predict(encoder_outputs
                                               +[target_seq, encoder_i_mask])

        candidates = []
        for j in range(beam_size) : 
            end_flg = target_data[j][2]
            if end_flg == 0 :
                output_vec = decoder_output[j, i, :]
                sorted_index = np.argsort(output_vec)[::-1]
                for k in range(beam_size) :
                    # 予測単語のインデックスを求める
                    sampled_token_index = sorted_index[k]
                    # 対数尤度算出
                    log_likelihood = math.log(output_vec[sampled_token_index])
                    log_likelihood /= math.pow((i+6)/6, alpha) 
                    log_likelihood += target_data[j][1]
                    # 予測単語インデックス列更新
                    target_seq_row = copy.copy(target_data[j][0])
                    sampled_char = indices_word[sampled_token_index]
                    if sampled_char == 'SSSS' :
                        end_flg = 1
                    else :
                        target_seq_row[i+1] = sampled_token_index
                    candidates.append([target_seq_row, 
                                       log_likelihood, end_flg])
            else :
                candidates.append([copy.copy(target_data[j][0]), 
                                   target_data[j][1], end_flg])
        # Update the target sequence 
        if i == 0 :
            index_order = np.arange(beam_size)
        else :
            log_likelihood_list = [candidates[j][1] 
                                   for j in range(len(candidates))]
            index_order = np.argsort(log_likelihood_list)[::-1]
        target_data = [candidates[index_order[j]] for j in range(beam_size)]
        end_flg_list = [target_data[j][2] for j in range(beam_size)]
        if min(end_flg_list) > 0 :
            break

    log_likelihood_mat = np.array([target_data[j][1] for j in range(beam_size)])
    max_index = np.argmax(log_likelihood_mat)
    most_likely_seq = target_data[max_index][0]
    index_list = np.nonzero(most_likely_seq)[0]
    decoded_sentence = ''
    for i in range(1,len(index_list)) :
        decoded_sentence += indices_word[most_likely_seq[index_list[i]]]

    return decoded_sentence   

5. 応答文生成

 前回投稿と同じ要領で、応答文を生成します。その結果は以下の通りです。前回の結果と比較してみます。

before
>> おはよう!
はい?
>> 今何してる?
うん。
>> ご飯食べた?
うん。
>> こんにちは。
ん?
>> それでは御免蒙りまするでござります。
ヘエ/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\してるから/\してるんですか。
after
>> おはよう!
ああ!
>> 今何してる?
うん。
>> ご飯食べた?
うん。
>> こんにちは。
ん?
>> それでは御免蒙りまするでござります。
ヘエ

 前回と比べると、「おはよう!」の返事が、なんか元気になったのと、不自然に長い応答が無くなったところが違っています。

 前回試してみた、日本文学シリーズについても比較してみます。

before
>> 吾輩は猫である。
ふーん。
>> 国境の長いトンネルを抜けると雪国だった。
的にはなかったからになってるんですからということをするんですからね。
after
>> 吾輩は猫である。
ふーん。
>> 国境の長いトンネルを抜けると雪国だった。
的には!

 こちらも、長い応答が改善されています。

6. おわりに

 以上が、Beam Searchを実装してみた結果です。応答文の生成は、確かに改善されています。今後、大規模なコーパスでTransformerを試してみる際にも、その効果が期待できそうです。

変更履歴

項番 日付 変更箇所 内容
1 2021/2/19 - 初版
5
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
5