Training with a CRF in tensorflow2

Introduction

In tensorflow1, training was easy: you could simply use crf_loss and crf_accuracy from keras-contrib.
Since tensorflow2, training with a CRF has become much harder to figure out, and I could find almost no articles explaining it in Japanese.
In tensorflow2, it seems you have to write a model that defines the loss and accuracy yourself.

Environment

tensorflow 2.4.1
tensorflow-addons 0.12.1

Model

model.py
import tensorflow as tf
from tensorflow.keras import layers as L
from tensorflow_addons.layers.crf import CRF
from tensorflow.python.keras.engine import data_adapter
from tensorflow_addons.text.crf import crf_log_likelihood

class ModelWithCRFLoss(tf.keras.Model):
    def __init__(self, output_names, *args, **kwargs):
        self.__output_names = output_names
        super().__init__(*args, **kwargs)

        self._accuracy_func = [tf.keras.metrics.Accuracy(name=name + "_accuracy") for name in self.__output_names]
        
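    # CRF loss for one output: the negative log-likelihood from tfa's
    # crf_log_likelihood, normalized by sequence length and averaged over the batch.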
    def crf_loss(self, y_true, y_pred):
        viterbi_sequence, potentials, sequence_length, chain_kernel = y_pred
        loss = -crf_log_likelihood(potentials, y_true, sequence_length, chain_kernel)[0] / tf.cast(sequence_length, tf.float32)
        return viterbi_sequence, sequence_length, tf.reduce_mean(loss)
    
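    # Forward pass: sum the CRF loss over all outputs and collect the Viterbi
    # decodes so that accuracy can be computed from them later.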
    def compute_loss(self, x, y, sample_weight, training=False):
        y_preds = self(x, training=training)

        total_loss = 0.0
        results = []
        for y_true, pred in zip(y, y_preds):
            viterbi_sequence, sequence_length, loss = self.crf_loss(y_true, pred)
            total_loss += loss
            results.append([viterbi_sequence, sequence_length, loss])

        return results, total_loss

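    # Custom training step: compute the total loss under a GradientTape, apply
    # the gradients, and report per-output loss and accuracy.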
    def train_step(self, data):
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            results, total_loss = self.compute_loss(x, y, sample_weight, training=True)

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        m = {"loss": total_loss}

        for i in range(len(results)):
            m[self.__output_names[i] + "_loss"] = results[i][-1]

        for i in range(len(results)):
            self._accuracy_func[i].update_state(y[i], results[i][0])
            m[self._accuracy_func[i].name] = self._accuracy_func[i].result()

        return m

    def test_step(self, data):
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

        # No GradientTape is needed during evaluation.
        results, total_loss = self.compute_loss(x, y, sample_weight, training=False)

        m = {"loss": total_loss}

        for i in range(len(results)):
            m[self.__output_names[i] + "_loss"] = results[i][-1]

        for i in range(len(results)):
            self._accuracy_func[i].update_state(y[i], results[i][0])
            m[self._accuracy_func[i].name] = self._accuracy_func[i].result()
            
        return m

def create_model():
    input_layer = L.Input(shape=(None, 256))
    x = L.Dense(256)(input_layer)
    x = L.LSTM(256, return_sequences=True)(x)
    x = L.Dense(256)(x)
    x = CRF(13, name="out")(x)

    # output_names specifies the names of the output layers.
    # If the model has multiple output layers, make sure the order matches the outputs list.
    model = ModelWithCRFLoss(inputs=[input_layer], outputs=[x], output_names=["out"])
    return model
Model: "model_with_crf_loss"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, None, 256)]       0
_________________________________________________________________
dense (Dense)                (None, None, 256)         65792
_________________________________________________________________
lstm (LSTM)                  (None, None, 256)         525312
_________________________________________________________________
dense_1 (Dense)              (None, None, 256)         65792
_________________________________________________________________
out (CRF)                    [(None, None), (None, Non 3536
=================================================================
Total params: 660,434
Trainable params: 660,432
Non-trainable params: 2
_________________________________________________________________

To use it, simply replace tf.keras.Model with ModelWithCRFLoss.
Note that the labels must be 2-dimensional, (batch, time_steps);
one-hot labels of shape (batch, time_steps, classes) cannot be used.
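
If your labels are currently one-hot encoded, they can be converted to class indices beforehand. A minimal sketch (the array y_onehot below is just randomly generated for illustration):

import numpy as np

# Hypothetical one-hot labels of shape (batch, time_steps, classes)
y_onehot = np.eye(13)[np.random.randint(0, 13, size=(32, 10))]

# Convert to integer class indices of shape (batch, time_steps)
y_labels = np.argmax(y_onehot, axis=-1).astype("int32")
print(y_labels.shape)  # (32, 10)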

Training

Training works just like with a regular model.
Since the loss and accuracy are defined inside the model, you only need to specify an optimizer in compile().

train.py
from model import create_model
from tensorflow_addons.optimizers import RAdam
model = create_model()
optimizer = RAdam()
# Specify only the optimizer
model.compile(optimizer)
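
For reference, here is a minimal fit sketch with randomly generated dummy data (the shapes are assumptions chosen to match the model above: sequence length 10, 256 features, 13 classes). The labels are passed as a list with one array per output layer, in the same order as output_names:

import numpy as np

# Dummy data for illustration only: 32 sequences of length 10 with 256 features,
# and integer class labels in [0, 13).
x_train = np.random.random((32, 10, 256)).astype("float32")
y_train = np.random.randint(0, 13, size=(32, 10)).astype("int32")

# One label array per output layer, in the same order as output_names.
model.fit(x_train, [y_train], batch_size=8, epochs=3)

model.evaluate can be called the same way, since test_step reuses compute_loss.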

Conclusion

Training with a CRF is now possible in tensorflow2.
However, it is nearly twice as slow as the keras-contrib CRF.
I hope this will be improved soon.
