How to improve the training efficiency of a TensorFlow seq2seq-attention language model

Posted at 2022-12-05

I studied a seq2seq-attention language model from a tutorial page, and applied the program published on that page to Japanese-English data. The training progress was as follows:

Epoch 1 Batch 0 Loss 1.2707
Epoch 1 Batch 100 Loss 0.7254
Epoch 1 Batch 200 Loss 0.7501
Epoch 1 Batch 300 Loss 0.7255
Epoch 1 Batch 400 Loss 0.6655
Epoch 1 Batch 500 Loss 0.5600
Epoch 1 Batch 600 Loss 0.5565
Epoch 1 Batch 700 Loss 0.5139
Epoch 1 Batch 800 Loss 0.5070
Epoch 1 Batch 900 Loss 0.4456
Epoch 1 Batch 1000 Loss 0.5286
Epoch 1 Loss 0.6024
Time taken for 1 epoch 11359.96979188919 sec

Epoch 2 Batch 0 Loss 0.4378
Epoch 2 Batch 100 Loss 0.3911
Epoch 2 Batch 200 Loss 0.4823
Epoch 2 Batch 300 Loss 0.4193
Epoch 2 Batch 400 Loss 0.4181
Epoch 2 Batch 500 Loss 0.4416
Epoch 2 Batch 600 Loss 0.4175
Epoch 2 Batch 700 Loss 0.4000
Epoch 2 Batch 800 Loss 0.3676
Epoch 2 Batch 900 Loss 0.3445
Epoch 2 Batch 1000 Loss 0.3164
Epoch 2 Loss 0.4078
Time taken for 1 epoch 11271.617814302444 sec

Epoch 3 Batch 0 Loss 0.3195
Epoch 3 Batch 100 Loss 0.2877
Epoch 3 Batch 200 Loss 0.3002
Epoch 3 Batch 300 Loss 0.2978
Epoch 3 Batch 400 Loss 0.2876
Epoch 3 Batch 500 Loss 0.2604
Epoch 3 Batch 600 Loss 0.2843
Epoch 3 Batch 700 Loss 0.3191
Epoch 3 Batch 800 Loss 0.2905
Epoch 3 Batch 900 Loss 0.2513
Epoch 3 Batch 1000 Loss 0.2656
Epoch 3 Loss 0.2893
Time taken for 1 epoch 11470.152668476105 sec

Epoch 4 Batch 0 Loss 0.2208
Epoch 4 Batch 100 Loss 0.2324
Epoch 4 Batch 200 Loss 0.1996
Epoch 4 Batch 300 Loss 0.2141
Epoch 4 Batch 400 Loss 0.2059
Epoch 4 Batch 500 Loss 0.2843
Epoch 4 Batch 600 Loss 0.1781
Epoch 4 Batch 700 Loss 0.2124
Epoch 4 Batch 800 Loss 0.2325
Epoch 4 Batch 900 Loss 0.1971
Epoch 4 Batch 1000 Loss 0.1857
Epoch 4 Loss 0.2090
Time taken for 1 epoch 11660.267095804214 sec

Epoch 5 Batch 0 Loss 0.1688
Epoch 5 Batch 100 Loss 0.1405
Epoch 5 Batch 200 Loss 0.1346
Epoch 5 Batch 300 Loss 0.1591
Epoch 5 Batch 400 Loss 0.1453
Epoch 5 Batch 500 Loss 0.1513
Epoch 5 Batch 600 Loss 0.1324
Epoch 5 Batch 700 Loss 0.1507
Epoch 5 Batch 800 Loss 0.1376
Epoch 5 Batch 900 Loss 0.1625
Epoch 5 Batch 1000 Loss 0.1504
Epoch 5 Loss 0.1533
Time taken for 1 epoch 11892.76987528801 sec

Epoch 6 Batch 0 Loss 0.1264
Epoch 6 Batch 100 Loss 0.0922
Epoch 6 Batch 200 Loss 0.0913
Epoch 6 Batch 300 Loss 0.1104
Epoch 6 Batch 400 Loss 0.0887
Epoch 6 Batch 500 Loss 0.1095
Epoch 6 Batch 600 Loss 0.1110
Epoch 6 Batch 700 Loss 0.1567
Epoch 6 Batch 800 Loss 0.1041
Epoch 6 Batch 900 Loss 0.1109
Epoch 6 Batch 1000 Loss 0.1477
Epoch 6 Loss 0.1138
Time taken for 1 epoch 12085.702548027039 sec

Epoch 7 Batch 0 Loss 0.0726
Epoch 7 Batch 100 Loss 0.0933
Epoch 7 Batch 200 Loss 0.0886
Epoch 7 Batch 300 Loss 0.0885
Epoch 7 Batch 400 Loss 0.0912
Epoch 7 Batch 500 Loss 0.0827
Epoch 7 Batch 600 Loss 0.0957
Epoch 7 Batch 700 Loss 0.0924
Epoch 7 Batch 800 Loss 0.0847
Epoch 7 Batch 900 Loss 0.0970
Epoch 7 Batch 1000 Loss 0.0813
Epoch 7 Loss 0.0870
Time taken for 1 epoch 12278.311335802078 sec

Epoch 8 Batch 0 Loss 0.0633
Epoch 8 Batch 100 Loss 0.0685
Epoch 8 Batch 200 Loss 0.0509
Epoch 8 Batch 300 Loss 0.0557
Epoch 8 Batch 400 Loss 0.0567
Epoch 8 Batch 500 Loss 0.0624
Epoch 8 Batch 600 Loss 0.0741
Epoch 8 Batch 700 Loss 0.0553
Epoch 8 Batch 800 Loss 0.0784
Epoch 8 Batch 900 Loss 0.0777
Epoch 8 Batch 1000 Loss 0.0709
Epoch 8 Loss 0.0670
Time taken for 1 epoch 12547.648915529251 sec

Epoch 9 Batch 0 Loss 0.0479
Epoch 9 Batch 100 Loss 0.0461
Epoch 9 Batch 200 Loss 0.0426
Epoch 9 Batch 300 Loss 0.0526
Epoch 9 Batch 400 Loss 0.0578
Epoch 9 Batch 500 Loss 0.0501
Epoch 9 Batch 600 Loss 0.0518
Epoch 9 Batch 700 Loss 0.0460
Epoch 9 Batch 800 Loss 0.0496
Epoch 9 Batch 900 Loss 0.0529
Epoch 9 Batch 1000 Loss 0.0593
Epoch 9 Loss 0.0536
Time taken for 1 epoch 14471.245846509933 sec

Epoch 10 Batch 0 Loss 0.0400
Epoch 10 Batch 100 Loss 0.0388
Epoch 10 Batch 200 Loss 0.0328
Epoch 10 Batch 300 Loss 0.0463
Epoch 10 Batch 400 Loss 0.0329
Epoch 10 Batch 500 Loss 0.0444
Epoch 10 Batch 600 Loss 0.0488
Epoch 10 Batch 700 Loss 0.0461
Epoch 10 Batch 800 Loss 0.0549
Epoch 10 Batch 900 Loss 0.0430
Epoch 10 Batch 1000 Loss 0.0458
Epoch 10 Loss 0.0443
Time taken for 1 epoch 15817.620344638824 sec

That was the baseline. I then modified the program so that the hidden state is fed into the decoder's GRU as its initial state, and so that the mask is applied to the attention rather than inside the loss function. Training with these changes produced the following results:

Epoch 1 Batch 0 Loss 9.5039
Epoch 1 Batch 100 Loss 0.7598
Epoch 1 Batch 200 Loss 0.7395
Epoch 1 Batch 300 Loss 0.7009
Epoch 1 Batch 400 Loss 0.6608
Epoch 1 Batch 500 Loss 0.6924
Epoch 1 Batch 600 Loss 0.6121
Epoch 1 Batch 700 Loss 0.5635
Epoch 1 Batch 800 Loss 0.5209
Epoch 1 Batch 900 Loss 0.4997
Epoch 1 Batch 1000 Loss 0.4910
Epoch 1 Loss 0.6741
Time taken for 1 epoch 33322.90072131157 sec

Epoch 2 Batch 0 Loss 0.4410
Epoch 2 Batch 100 Loss 0.5025
Epoch 2 Batch 200 Loss 0.4676
Epoch 2 Batch 300 Loss 0.3934
Epoch 2 Batch 400 Loss 0.4159
Epoch 2 Batch 500 Loss 0.3864
Epoch 2 Batch 600 Loss 0.3708
Epoch 2 Batch 700 Loss 0.3715
Epoch 2 Batch 800 Loss 0.3738
Epoch 2 Batch 900 Loss 0.3800
Epoch 2 Batch 1000 Loss 0.3732
Epoch 2 Loss 0.4045
Time taken for 1 epoch 32844.25040769577 sec

Epoch 3 Batch 0 Loss 0.2990
Epoch 3 Batch 100 Loss 0.3201
Epoch 3 Batch 200 Loss 0.2654
Epoch 3 Batch 300 Loss 0.2817
Epoch 3 Batch 400 Loss 0.2544
Epoch 3 Batch 500 Loss 0.2355
Epoch 3 Batch 600 Loss 0.2483
Epoch 3 Batch 700 Loss 0.2741
Epoch 3 Batch 800 Loss 0.2979
Epoch 3 Batch 900 Loss 0.2217
Epoch 3 Batch 1000 Loss 0.2011
Epoch 3 Loss 0.2610
Time taken for 1 epoch 21008.75523161888 sec

Epoch 4 Batch 0 Loss 0.1692
Epoch 4 Batch 100 Loss 0.1473
Epoch 4 Batch 200 Loss 0.1394
Epoch 4 Batch 300 Loss 0.1633
Epoch 4 Batch 400 Loss 0.1938
Epoch 4 Batch 500 Loss 0.1591
Epoch 4 Batch 600 Loss 0.1727
Epoch 4 Batch 700 Loss 0.2025
Epoch 4 Batch 800 Loss 0.2221
Epoch 4 Batch 900 Loss 0.2051
Epoch 4 Batch 1000 Loss 0.1764
Epoch 4 Loss 0.1743
Time taken for 1 epoch 19569.03325510025 sec

Epoch 5 Batch 0 Loss 0.0951
Epoch 5 Batch 100 Loss 0.1011
Epoch 5 Batch 200 Loss 0.1123
Epoch 5 Batch 300 Loss 0.1192
Epoch 5 Batch 400 Loss 0.1094
Epoch 5 Batch 500 Loss 0.1174
Epoch 5 Batch 600 Loss 0.1163
Epoch 5 Batch 700 Loss 0.1325
Epoch 5 Batch 800 Loss 0.1400
Epoch 5 Batch 900 Loss 0.1339
Epoch 5 Batch 1000 Loss 0.1125
Epoch 5 Loss 0.1228
Time taken for 1 epoch 19797.13765978813 sec

Epoch 6 Batch 0 Loss 0.0875
Epoch 6 Batch 100 Loss 0.0693
Epoch 6 Batch 200 Loss 0.0780
Epoch 6 Batch 300 Loss 0.0973
Epoch 6 Batch 400 Loss 0.0801
Epoch 6 Batch 500 Loss 0.0987
Epoch 6 Batch 600 Loss 0.0921
Epoch 6 Batch 700 Loss 0.0776
Epoch 6 Batch 800 Loss 0.0880
Epoch 6 Batch 900 Loss 0.1114
Epoch 6 Batch 1000 Loss 0.1007
Epoch 6 Loss 0.0914
Time taken for 1 epoch 19926.429337263107 sec

Epoch 7 Batch 0 Loss 0.0728
Epoch 7 Batch 100 Loss 0.0823
Epoch 7 Batch 200 Loss 0.0622
Epoch 7 Batch 300 Loss 0.0604
Epoch 7 Batch 400 Loss 0.0730
Epoch 7 Batch 500 Loss 0.0674
Epoch 7 Batch 600 Loss 0.0769
Epoch 7 Batch 700 Loss 0.0677
Epoch 7 Batch 800 Loss 0.0503
Epoch 7 Batch 900 Loss 0.0735
Epoch 7 Batch 1000 Loss 0.0622
Epoch 7 Loss 0.0699
Time taken for 1 epoch 20041.852863788605 sec

Epoch 8 Batch 0 Loss 0.0383
Epoch 8 Batch 100 Loss 0.0448
Epoch 8 Batch 200 Loss 0.0629
Epoch 8 Batch 300 Loss 0.0531
Epoch 8 Batch 400 Loss 0.0412
Epoch 8 Batch 500 Loss 0.0674
Epoch 8 Batch 600 Loss 0.0534
Epoch 8 Batch 700 Loss 0.0670
Epoch 8 Batch 800 Loss 0.0693
Epoch 8 Batch 900 Loss 0.0540
Epoch 8 Batch 1000 Loss 0.0576
Epoch 8 Loss 0.0558
Time taken for 1 epoch 20089.24049091339 sec

Epoch 9 Batch 0 Loss 0.0436
Epoch 9 Batch 100 Loss 0.0366
Epoch 9 Batch 200 Loss 0.0296
Epoch 9 Batch 300 Loss 0.0284
Epoch 9 Batch 400 Loss 0.0308
Epoch 9 Batch 500 Loss 0.0444
Epoch 9 Batch 600 Loss 0.0549
Epoch 9 Batch 700 Loss 0.0395
Epoch 9 Batch 800 Loss 0.0441
Epoch 9 Batch 900 Loss 0.0529
Epoch 9 Batch 1000 Loss 0.0538
Epoch 9 Loss 0.0462
Time taken for 1 epoch 20214.765110969543 sec

Epoch 10 Batch 0 Loss 0.0265
Epoch 10 Batch 100 Loss 0.0422
Epoch 10 Batch 200 Loss 0.0410
Epoch 10 Batch 300 Loss 0.0327
Epoch 10 Batch 400 Loss 0.0417
Epoch 10 Batch 500 Loss 0.0452
Epoch 10 Batch 600 Loss 0.0412
Epoch 10 Batch 700 Loss 0.0493
Epoch 10 Batch 800 Loss 0.0350
Epoch 10 Batch 900 Loss 0.0406
Epoch 10 Batch 1000 Loss 0.0491
Epoch 10 Loss 0.0405
Time taken for 1 epoch 20369.89773750305 sec

After 10 epochs of training, the final loss improved from 0.0443 to 0.0405, roughly a one-tenth (about 9%) relative improvement. The time per epoch grew from roughly 11,000-16,000 sec in the first run to roughly 20,000 sec in the second, but this is most likely due to the difference in performance between the machines used for the two runs. The epoch times of over 30,000 sec at the start of the second run occurred because that training was run concurrently with other jobs.

The modifications to the program are as follows. Around the decoder's GRU (including the mask handling):

class Decoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    super(Decoder, self).__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.dec_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc = tf.keras.layers.Dense(vocab_size)

    # for attention
    self.attention = BahdanauAttention(self.dec_units)

  def call(self, x, hidden, enc_output, mask):
    # enc_output shape == (batch_size, max_length, hidden_size)
    context_vector, attention_weights = self.attention(hidden, enc_output, mask)

    # x shape after passing through the embedding layer == (batch_size, 1, embedding_dim)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # pass the concatenated vector to the GRU layer
    output, state = self.gru(x, initial_state = hidden)                   #POINT (feed the hidden state as the GRU's initial state)

    # output shape == (batch_size * 1, hidden_size)
    output = tf.reshape(output, (-1, output.shape[2]))

    # output shape == (batch_size, vocab)
    x = self.fc(output)

    return x, state, attention_weights

Regarding the mask: a mask-creation function was added.

def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

    # 1.0 at padded (zero) positions, 0.0 elsewhere;
    # the mask is added to the attention logits later.
    return seq  # (batch_size, seq_len)
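
As a quick sanity check (a minimal sketch with a hypothetical two-sentence batch, assuming import tensorflow as tf as in the tutorial), the function returns 1.0 at padded positions and 0.0 elsewhere:

sample = tf.constant([[5, 3, 9, 0, 0],
                      [7, 2, 0, 0, 0]])
print(create_padding_mask(sample))
# tf.Tensor(
# [[0. 0. 0. 1. 1.]
#  [0. 0. 1. 1. 1.]], shape=(2, 5), dtype=float32)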

The mask in the loss function:

optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
  #mask = tf.math.logical_not(tf.math.equal(real, 0))  #POINT
  loss_ = loss_object(real, pred)

  #mask = tf.cast(mask, dtype=loss_.dtype)             #POINT
  #loss_ *= mask                                       #POINT

  return tf.reduce_mean(loss_)
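
As a minimal sanity check of this unmasked loss (a sketch with a toy vocabulary of 8 entries standing in for the real target vocabulary; index 0 represents padding):

real = tf.constant([4, 7, 0])                  # one decoder time step; the last position is padding
pred = tf.random.uniform((3, 8))               # unnormalized logits, shape (batch_size, vocab_size)
print(loss_function(real, pred).numpy())       # a scalar; padded positions now also contribute to the mean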

Attention-related changes:

class BahdanauAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values, mask):                                                  #POINT
    # hidden shape == (batch_size, hidden size)
    # hidden_with_time_axis shape == (batch_size, 1, hidden size)
    # the addition below is performed this way in order to compute the score
    hidden_with_time_axis = tf.expand_dims(query, 1)

    # score shape == (batch_size, max_length, 1)
    # the last axis is 1 because we apply self.V to get the score
    # the shape of the tensor before applying self.V is (batch_size, max_length, units)
    score = self.V(tf.nn.tanh(
        self.W1(values) + self.W2(hidden_with_time_axis)))
    # so that the attention_weights become 0 at the padded (zero) positions of the input
    score_2dim = tf.squeeze( score, axis = 2 )                                          #POINT
    score_2dim  += (mask * -1e9)                                                        #POINT
    score = tf.expand_dims( score_2dim, axis = 2 )                                      #POINT

    #print( "shape of score:{}".format( score.shape ))
    # attention_weights shape == (batch_size, max_length, 1)
    attention_weights = tf.nn.softmax(score, axis=1)
    #print( "shape of attention_weights:{}".format( attention_weights.shape))
    # context_vector shape after the sum == (batch_size, hidden_size)
    context_vector = attention_weights * values
    context_vector = tf.reduce_sum(context_vector, axis=1)

    return context_vector, attention_weights

Creating the attention instance:

attention_layer = BahdanauAttention(10)
mask = tf.random.uniform( (sample_output.shape[0], sample_output.shape[1]))                               #POINT (dummy mask, only used to check shapes)
attention_result, attention_weights = attention_layer(sample_hidden, sample_output, mask)                 #POINT
 
print("Attention result shape: (batch size, units) {}".format(attention_result.shape))
print("Attention weights shape: (batch_size, sequence_length, 1) {}".format(attention_weights.shape))

The Decoder class:

class Decoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    super(Decoder, self).__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.dec_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc = tf.keras.layers.Dense(vocab_size)

    # for attention
    self.attention = BahdanauAttention(self.dec_units)

  def call(self, x, hidden, enc_output, mask):                                         #POINT
    # enc_output shape == (batch_size, max_length, hidden_size)
    context_vector, attention_weights = self.attention(hidden, enc_output, mask)       #POINT

    # x shape after passing through the embedding layer == (batch_size, 1, embedding_dim)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # pass the concatenated vector to the GRU layer
    output, state = self.gru(x, initial_state = hidden )

    # output shape == (batch_size * 1, hidden_size)
    output = tf.reshape(output, (-1, output.shape[2]))

    # output shape == (batch_size, vocab)
    x = self.fc(output)

    return x, state, attention_weights

Creating the decoder instance:

decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)
mask = tf.random.uniform( (sample_output.shape[0], sample_output.shape[1]))                               #POINT (dummy mask, only used to check shapes)
sample_decoder_output, _, _ = decoder(tf.random.uniform((64, 1)),
                                      sample_hidden, sample_output, mask)                                 #POINT
print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))

The train_step function, which calls the decoder:

@tf.function
def train_step(inp, targ, enc_hidden):
  loss = 0
  
  mask = create_padding_mask( inp )                                                     #POINT
         
  with tf.GradientTape() as tape:
    enc_output, enc_hidden = encoder(inp, enc_hidden)
 
    dec_hidden = enc_hidden
 
    dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)      
 
    # Teacher forcing - feeding the target as the next input
    for t in range(1, targ.shape[1]):
      # passing enc_output to the decoder
      predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output, mask)     #POINT
 
      loss += loss_function(targ[:, t], predictions)
 
      # using teacher forcing
      dec_input = tf.expand_dims(targ[:, t], 1)
 
  batch_loss = (loss / int(targ.shape[1]))
 
  variables = encoder.trainable_variables + decoder.trainable_variables
 
  gradients = tape.gradient(loss, variables)
 
  optimizer.apply_gradients(zip(gradients, variables))
   
  return batch_loss
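
For reference, the outer training loop that calls train_step (and prints the logs shown above) follows the tutorial. A minimal sketch, assuming dataset, steps_per_epoch and encoder.initialize_hidden_state() are set up as in the tutorial:

import time

EPOCHS = 10

for epoch in range(EPOCHS):
  start = time.time()

  enc_hidden = encoder.initialize_hidden_state()
  total_loss = 0

  for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
    batch_loss = train_step(inp, targ, enc_hidden)
    total_loss += batch_loss

    if batch % 100 == 0:
      print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch, batch_loss.numpy()))

  print('Epoch {} Loss {:.4f}'.format(epoch + 1, total_loss / steps_per_epoch))
  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))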

The evaluate function, which calls the decoder:

def evaluate(sentence):
    attention_plot = np.zeros((max_length_targ, max_length_inp))
     
    sentence = preprocess_sentence(sentence)
 
    inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],
                                                           maxlen=max_length_inp,
                                                           padding='post')
    inputs = tf.convert_to_tensor(inputs)

    mask = create_padding_mask( inputs )                                                     #POINT

    result = ''
 
    hidden = [tf.zeros((1, units))]
    enc_out, enc_hidden = encoder(inputs, hidden)
 
    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)
     
    for t in range(max_length_targ):
        predictions, dec_hidden, attention_weights = decoder(dec_input,
                                                             dec_hidden,
                                                             enc_out, mask)                   #POINT
         
        # storing the attention weights to plot later on
        attention_weights = tf.reshape(attention_weights, (-1, ))
        attention_plot[t] = attention_weights.numpy()
 
        predicted_id = tf.argmax(predictions[0]).numpy()
 
        result += targ_lang.index_word[predicted_id] + ' '
 
        if targ_lang.index_word[predicted_id] == '<end>':
            return result, sentence, attention_plot
         
        # the predicted ID is fed back into the model
        dec_input = tf.expand_dims([predicted_id], 0)
 
    return result, sentence, attention_plot
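
A minimal usage sketch, following the tutorial's translate helper (the attention plot is omitted here; any input sentence must tokenize into words that appear in the training vocabulary):

def translate(sentence):
    result, sentence, attention_plot = evaluate(sentence)

    print('Input: %s' % (sentence))
    print('Predicted translation: {}'.format(result))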

With the above changes, I think it will work well.
