こんにちは。@rheza_h です。
この記事はACCESS Advent Calendar 2019 22日目の記事です。
NLP は初めてですが、今回の記事は Huggingface が提供している DistilBERT を紹介し、それを使って、Android 上で動かしてみた記事です。
この記事で BERT については書かれていません。
BERT については論文 または @shotasakamoto のプレゼン資料が参考になれると思います。
社内リンク
Huggingface とは?
Natural Language Processing (NLP - 自然言語処理) を中心に研究開発をやっています。
Huggingface は 2016 に Brooklyn, New York で始まりました。
2017 にチャットボットをリリースしました。
Huggingface は自社の NLP モデルを開発して、Hierarchical Multi-Task Learning (HTML) と呼ばれています。
Chatty, Talking Dog, Talking Egg, Boloss と言う iOS アプリを開発しています。
DistilBERT と言うモデルを NeurIPS 2019 に公開されました。
DistilBERTの話
小さく、早く、安く、軽く
state-of-the-arts の NLP モデルはほとんど large-scale language model を使われています。Transformer (Vaswani et al.,) のベースで研究されていて、最近の state-of-the-art モデルのパラメーターのサイズが大きくなっています。例えば NVIDIA が作った、MegatronLM と言うモデルは 8.3億パラメーターがあります。それは約160GBのテキストデータで学習されているそうです。
DistilBERT は名前の通り、"Distil"、必要・大事な部分だけを使用して、モデルが小さくして、精度は耐えられる程度で研究しています。
Knowledge Distillation [Bucila et al., 2006, Hinton et al., 2015] と言う方法をやっています。"Teacher-Student" 学習でもよく言われています。
モデルは2つあります。
- "学生"モデルと呼ばれています。
- 小さなモデル
- このモデルが"先生"モデルに似たような結果を出せるよう期待しています
- "先生"モデル
- ベースモデルは"学生"モデルと同じ
Knowledge Distillation の流れはいろんな流れがありますが、一番大切なのは Loss Function を組み合わせところだと思います。
pre-trainedを使う場合、
図2. Knowledge Distillation
DistilBERT は名前の通り、BERT がベースになりますが、小さいバージョンが作られました。
ほとんどのアーキテクチャはそのままにさせましたが、token-type embeddings と pooler だけを削除し、レイヤーの数を2の因数に減らします。
評価 (モデルパフォーマンス)
DistillBERT は学習のため、BERT と同じコルパスを使って、英語の Wikipedia と Toronto Book Corpus [Zhu et al., 2015]です。
8つの 16GB V100 GPUs を使って 90時間ぐらいかかったそうです。
BERT と比べたら、パラメーターの数は 40% 少ないし、パフォーマンスは約2~3%しか減っていません。
Android の実装
Huggingface は以下のレポジトリでサンプルアプリを公開します。
https://github.com/huggingface/tflite-android-transformers
DistilBERTをスクラッチから学習すると時間がかかりすぎるので、huggingface は tflite版のモデルを公開されています。
(コードは Huggingface の github で公開されています)
その tflite モデルをロードして、inference に使えばうまく動いています。
モデルのロード
private static final int NUM_LITE_THREADS = 4;
public synchronized void loadModel() {
try {
ByteBuffer buffer = loadModelFile(this.context.getAssets());
Interpreter.Options opt = new Interpreter.Options();
opt.setNumThreads(NUM_LITE_THREADS);
tflite = new Interpreter(buffer, opt);
Log.v(TAG, "TFLite model loaded.");
} catch (IOException ex) {
Log.e(TAG, ex.getMessage());
}
}
public MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {
// MODEL_PATH は assets にあります
try (AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
predict をする前に、query の feature extraction をやることが必要です。
空白とか、query に要らない部分を抜いて、必要な部分だけを predict に渡します。
Tokenization と関係するコードは以下のところで
https://github.com/huggingface/tflite-android-transformers/tree/master/bert/src/main/java/co/huggingface/android_transformers/bertqa/tokenization
感想
- DistilBERT のおかげで BERT でも Android 端末で実行することができます。
- tflite のフォマットでpre-trained モデルが公開されているので、すぐ使えます。
- 答えの検索処理はコンテンツ内容によります。公開されたデータセットを使ってみるところ、Nexus 6P 端末で速度は 1~ 3秒ぐらいかかります。
- モデルのサイズは結構大きい (250MB) です。
- 次は日本語のモデルを試してみたいです。以下には CL-Tohoku が公開された 日本語の BERT です。
https://github.com/cl-tohoku/bert-japanese
最後に
明日は @hey3 の初めての記事です。お楽しみに!
参考
https://huggingface.co/
https://golden.com/wiki/Hugging_Face
https://techcrunch.com/2019/12/17/hugging-face-raises-15-million-to-build-the-definitive-natural-language-processing-library/
DistilBERT Paper
BERT Paper
Knowledge Distillation Paper
https://medium.com/huggingface/distilbert-8cf3380435b5
https://github.com/huggingface/tflite-android-transformers
https://nervanasystems.github.io/distiller/knowledge_distillation.html
https://towardsdatascience.com/model-distillation-and-compression-for-recommender-systems-in-pytorch-5d81c0f2c0ec
https://qiita.com/nekoumei/items/7b911c61324f16c43e7e