6
4

More than 3 years have passed since last update.

BERTのファインチューニングで威力を発揮する「slanted triangular learning rate」をKerasで実装した

Last updated at Posted at 2020-02-02

概要

Kaggle の自然言語処理のコンペに参戦中で、BERTをfine-tuningしたかった。Slanted triangular learning rate (STLR) が良いという情報を得て、Kerasで実装してみたところ、かなり精度が上がった。

(追記)その後シルバーメダルをゲットできた。

Slanted triangualr learning rate

参考にしたのは↓の論文。

Fig.2 を見れば分かる通り、学習初期の学習率の warm-up と中盤以降の学習率の減衰を、どちらも線形にする。傾いた三角形のようなので「slanted-triangular」。

ちなみにSTLRを使おうと思った元の論文は↓

Keras による実装

Keras の Callbacks の仕組みを使えば実現できる。STLR は epoch ごとではなく、iteration (Keras の用語でいうところの steps)ごとに学習率を変更させる必要があるので、LearningRateScheduler は使えない。Callbacks クラスを継承してスクラッチで作る必要がある。

class SlantedTriangularScheduler(Callback):

    def __init__(self,
                 lr_max: float = 0.001,
                 cut_frac: float = 0.1,
                 ratio: float = 32):
        self.lr_max = lr_max
        self.cut_frac = cut_frac
        self.ratio = ratio

    def on_train_begin(self, logs = None):
        epochs = self.params['epochs']
        steps = self.params['steps']
        self.cut = epochs * steps * self.cut_frac
        self.iteration = 0

    def on_batch_begin(self, batch: int, logs = None):
        t = self.iteration
        cut = self.cut
        if t < cut:
            p = t / cut
        else:
            p = 1 - (t - cut) / (cut * (1 / self.cut_frac - 1))
        lr = self.lr_max * (1 + p * (self.ratio - 1)) / self.ratio
        K.set_value(self.model.optimizer.lr, lr)
        self.iteration += 1

変数名などは、なるべく原著論文のEq(3)と同じものを用いている。

How to Fine-Tune BERT for Text Classification? で示されている通り、BERT の fine-tuning には↓この組み合わせがよく効いた。

ハイパーパラメーター
lr_max 2e-5
cut_frac 0.1
ratio 32
6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4