3
2

More than 3 years have passed since last update.

OpenNMT-pyソースコードリーディング(1)

Posted at

OpenNMT-pyソースコードリーディングの準備

OpenNMT-pyのソースコードリーディングを行う。

前提条件

OpenNMT-pyとは

OpenNMT-pyについて、簡単に説明しておこうと思う。

OpenNMTは、オープンソースの機械翻訳フレームワークとツール集である。

PyTorchで実装されているOpenNMT-pyと、tensorflowで実装さているOpenNMT-tfがある。

ここでは、OpenNMT-py を使う。

PyTorchは、フロントがPythonで書かれており、内部は、C++で実装されている。

なので、基本的には、Pythonのコードを読んでいく。

パフォーマンス等の改善で、そのうちC++や内部で呼ばれているCUDA等に踏み込めればよいのだが。。

Transformerモデル

前に京都フリー翻訳タスク (KFTT)のデータを使って、RNNモデルの翻訳を実行した。

その時の学習は、以下のコマンドを実行していた。

# 学習を実行(10000ステップまで1000ステップごとにモデルを保存)
onmt_train -gpu_ranks 0 --train_steps 10000 -data kftt/demo -save_model demo-model -save_checkpoint_steps 1000

今回は、Transformerモデルで学習させてみたいと思う。以下のコマンドを実行すればよい。

onmt_train -data kftt/demo -save_model kyoto-trf-model -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 -encoder_type transformer -decoder_type transformer -position_encoding -train_steps 200000 -max_generator_batches 2 -dropout 0.1 -batch_size 4096 -batch_type tokens -normalization tokens -accum_count 2 -optim adam -adam_beta2 0.998 -decay_method noam -warmup_steps 8000 -learning_rate 2 -max_grad_norm 0 -param_init 0 -param_init_glorot -label_smoothing 0.1 -valid_steps 1000 -save_checkpoint_steps 1000 --train_steps 10000 -world_size 1 -gpu_rank 0

いろいろなパラメータを指定しているが、今なら何となくパラメータの意味がわかるのでは、ないだろうか。。

  • layers:[6] transformerのencoder/decorderの層の数
  • rnn_size:[512] これはどこの隠れ層の数だろう。。ソースコードを調べよう。transformerモデルに必要?
    • Size of rnn hidden states. Overwrites enc_rnn_size and dec_rnn_size
  • word_vec_size:[512] 単語ベクトルの次元数
  • transformer_ff:[2048] transformerのFeedFowardの隠れ層の数
  • heads:[8] MultiHead-Attentionの数
  • position_encoding: 位置エンコーディングを使う。 (encoder_type/decoder_typeのtransformerとセット)
  • max_generator_batches: [2] 単語を生成するバッチの最大値?
    • Maximum batches of words in a sequence to run the generator on in parallel. Higher is faster, but uses more memory. Set to 0 to disable.
  • dropout: dropoutさせる確率(過学習を防ぐため、いろんな箇所でdropoutさせているらしい。)
  • batch_size:[4096] バッチ数
  • batch_type:[tokens] tokens以外にsentsがある。
  • normalization:[tokens] 正規化 tokens以外にsentsがある。Normalization method of the gradient.
  • accum_count:[2] 畳み込みの数?
  • optim:[adam] 最適化アルゴリズム
  • adam_beta2:[0.998] 最適化のハイパーパラメータ
  • decay_method:[noam] {noam,noamwd,rsqrt,none}のどれかを選択する。何のこと?何かを選択するときに使うアルゴリズムかな。
  • learning_rate: [2] わからん
  • max_grad_norm: 最適化のハイパーパラメータ
  • param_init: 最適化のハイパーパラメータだろう。
  • param_init_glorot: 最適化のハイパーパラメータじゃね。
  • label_smoothing: わからん
    • Label smoothing value epsilon. Probabilities of all non-true labels will be smoothed by epsilon / (vocab_size - 1). Set to zero to turn off label smoothing. For more detailed information, see: https://arxiv.org/abs/1512.00567
  • world_size:[1] 使用するGPUの数
  • gpu_rank:[0] GPUのID

Docs » Train

ソースコードの構成

基本的に、「onmt」配下を見ていく。

OpenNMT-py$ tree -L 1 onmt
onmt
├── __init__.py
├── bin                # 実行コマンド
├── decoders           # デコーダ達
├── encoders           # エンコーダ達
├── inputters          # 入力ファイルのデータセット
├── model_builder.py   # モデルを作る人
├── models             # モデルのベース
├── modules            # MultiHead-Attention等の部品
├── opts.py            # オプション
├── tests              # テスト
├── train_single.py    # 学習実行
├── trainer.py         # 学習のメイン
├── translate          # 推論
└── utils              # 便利関数
3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2