8
9

More than 3 years have passed since last update.

Huggingface の DistilBERT を使って、Android で NLP を!

Posted at

こんにちは。@rheza_h です。
この記事はACCESS Advent Calendar 2019 22日目の記事です。

NLP は初めてですが、今回の記事は Huggingface が提供している DistilBERT を紹介し、それを使って、Android 上で動かしてみた記事です。
この記事で BERT については書かれていません。
BERT については論文 または @shotasakamoto のプレゼン資料が参考になれると思います。
社内リンク

Huggingface :hugging: とは?

Natural Language Processing (NLP - 自然言語処理) を中心に研究開発をやっています。
Huggingface は 2016 に Brooklyn, New York で始まりました。
2017 にチャットボットをリリースしました。
Huggingface は自社の NLP モデルを開発して、Hierarchical Multi-Task Learning (HTML) と呼ばれています。
Chatty, Talking Dog, Talking Egg, Boloss と言う iOS アプリを開発しています。
DistilBERT と言うモデルを NeurIPS 2019 に公開されました。

DistilBERTの話

小さく、早く、安く、軽く

state-of-the-arts の NLP モデルはほとんど large-scale language model を使われています。Transformer (Vaswani et al.,) のベースで研究されていて、最近の state-of-the-art モデルのパラメーターのサイズが大きくなっています。例えば NVIDIA が作った、MegatronLM と言うモデルは 8.3億パラメーターがあります。それは約160GBのテキストデータで学習されているそうです。

image.png
図1. NLP の進捗

DistilBERT は名前の通り、"Distil"、必要・大事な部分だけを使用して、モデルが小さくして、精度は耐えられる程度で研究しています。
Knowledge Distillation [Bucila et al., 2006, Hinton et al., 2015] と言う方法をやっています。"Teacher-Student" 学習でもよく言われています。
モデルは2つあります。

  • "学生"モデルと呼ばれています。
    • 小さなモデル
    • このモデルが"先生"モデルに似たような結果を出せるよう期待しています
  • "先生"モデル
    • ベースモデルは"学生"モデルと同じ

Knowledge Distillation の流れはいろんな流れがありますが、一番大切なのは Loss Function を組み合わせところだと思います。
image.png
pre-trainedを使う場合、

図2. Knowledge Distillation

DistilBERT は名前の通り、BERT がベースになりますが、小さいバージョンが作られました。
ほとんどのアーキテクチャはそのままにさせましたが、token-type embeddings と pooler だけを削除し、レイヤーの数を2の因数に減らします。

評価 (モデルパフォーマンス)

DistillBERT は学習のため、BERT と同じコルパスを使って、英語の Wikipedia と Toronto Book Corpus [Zhu et al., 2015]です。
8つの 16GB V100 GPUs を使って 90時間ぐらいかかったそうです。

BERT と比べたら、パラメーターの数は 40% 少ないし、パフォーマンスは約2~3%しか減っていません。
image.png

Android の実装

Huggingface は以下のレポジトリでサンプルアプリを公開します。
https://github.com/huggingface/tflite-android-transformers

DistilBERTをスクラッチから学習すると時間がかかりすぎるので、huggingface は tflite版のモデルを公開されています。
(コードは Huggingface の github で公開されています)
その tflite モデルをロードして、inference に使えばうまく動いています。

モデルのロード

private static final int NUM_LITE_THREADS = 4;
public synchronized void loadModel() {
    try {
      ByteBuffer buffer = loadModelFile(this.context.getAssets());
      Interpreter.Options opt = new Interpreter.Options();
      opt.setNumThreads(NUM_LITE_THREADS);
      tflite = new Interpreter(buffer, opt);
      Log.v(TAG, "TFLite model loaded.");
    } catch (IOException ex) {
      Log.e(TAG, ex.getMessage());
    }
}

public MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {
    // MODEL_PATH は assets にあります
    try (AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH); 
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
      FileChannel fileChannel = inputStream.getChannel();
      long startOffset = fileDescriptor.getStartOffset();
      long declaredLength = fileDescriptor.getDeclaredLength();
      return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
}

predict をする前に、query の feature extraction をやることが必要です。
空白とか、query に要らない部分を抜いて、必要な部分だけを predict に渡します。

Tokenization と関係するコードは以下のところで
https://github.com/huggingface/tflite-android-transformers/tree/master/bert/src/main/java/co/huggingface/android_transformers/bertqa/tokenization

例:

感想

  • DistilBERT のおかげで BERT でも Android 端末で実行することができます。
  • tflite のフォマットでpre-trained モデルが公開されているので、すぐ使えます。
  • 答えの検索処理はコンテンツ内容によります。公開されたデータセットを使ってみるところ、Nexus 6P 端末で速度は 1~ 3秒ぐらいかかります。
  • モデルのサイズは結構大きい (250MB) です。
  • 次は日本語のモデルを試してみたいです。以下には CL-Tohoku が公開された 日本語の BERT です。 https://github.com/cl-tohoku/bert-japanese

最後に

明日は @hey3 の初めての記事です。お楽しみに!

参考

https://huggingface.co/
https://golden.com/wiki/Hugging_Face
https://techcrunch.com/2019/12/17/hugging-face-raises-15-million-to-build-the-definitive-natural-language-processing-library/
DistilBERT Paper
BERT Paper
Knowledge Distillation Paper
https://medium.com/huggingface/distilbert-8cf3380435b5
https://github.com/huggingface/tflite-android-transformers
https://nervanasystems.github.io/distiller/knowledge_distillation.html
https://towardsdatascience.com/model-distillation-and-compression-for-recommender-systems-in-pytorch-5d81c0f2c0ec
https://qiita.com/nekoumei/items/7b911c61324f16c43e7e

8
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
9