1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

CloudTPUで学習済みのtransformerモデルをローカルで実行する方法

Last updated at Posted at 2020-01-31

以前書いたCloudTPUでtransformer英日翻訳モデルを学習&推論する手順では、CloudTPUでtransformer英日翻訳モデルを学習し、推論も行いました。
今回はCloudTPUで学習済みのtransformerをローカルのDockerコンテナで実行する方法を説明します。
コードはこちらです。
https://github.com/yolo-kiyoshi/transformer_python_exec

前提

GCS、ローカルで以下のディレクトリ構造でファイルが配置されているとします。

GCSディレクトリ

ディレクトリ構造
bucket
├── training/
│   └── transformer_ende/
│       ├── checkpoint
│       ├── model.ckpt-****.data-00000-of-00001
│       ├── model.ckpt-****.index
│       └── model.ckpt-****.meta
└── transformer/
    └── vocab.translate_jpen.****.subwords

ローカルディレクトリ

リポジトリをcloneします。

git clone https://github.com/yolo-kiyoshi/transformer_python_exec.git
ディレクトリ構造
.
├── Dockerfile
├── .env.sample
├── Pipfile
├── Pipfile.lock
├── README.md
├── decode.ipynb
├── docker-compose.yml
├── training/
│   └── transformer_ende/
└── transformer/

準備

Google Credentialファイル

サービスアカウントのCredentialファイル(json)をダウンロードし、README.mdと同一ディレクトリに配置します。

環境変数

.env.sampleを複製&リネームし、.envを作成します。

.env
# 上で配置したCredentialファイルのパスを記載
GOOGLE_APPLICATION_CREDENTIALS=*****.json
BUDGET_NAME=
# CloudTPUで学習したときと同一の設定
PROBLEM=translate_jpen
DATA_DIR=transformer
TRAIN_DIR=training/transformer_ende/
HPARAMS=transformer_tpu
MODEL=transformer

Dockerイメージの作成とコンテナの起動

以下のコマンドを実行後、http://localhost:8080/lab にアクセスするとJupyter labを操作できます

docker-compose up -d

Notebook

transformer学習結果をGCSからダウンロード

GCSからtransformer学習過程で作成されるcheckpointファイル一式とvocabファイルをローカルにダウンロードする。

# GCSからファイルをダウンロードするメソッド(https://cloud.google.com/storage/docs/downloading-objects?hl=ja)
def download_blob(bucket_name, source_blob_name, destination_file_name):
    """Downloads a blob from the bucket."""
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    blob = bucket.blob(source_blob_name)

    blob.download_to_filename(destination_file_name)

    print('Blob {} downloaded to {}.'.format(
        source_blob_name,
        destination_file_name))

# GCSのファイル一覧取得メソッドを参考
# https://cloud.google.com/storage/docs/listing-objects?hl=ja#storage-list-objects-python
def list_match_file_with_prefix(bucket_name, prefix, search_path):
    """Lists all the blobs in the bucket that begin with the prefix."""
    
    storage_client = storage.Client()

    # Note: Client.list_blobs requires at least package version 1.17.0.
    blobs = storage_client.list_blobs(bucket_name, prefix=prefix, delimiter=None)

    file_list = [blob.name for blob in blobs if search_path in blob.name]
    
    return file_list

# 環境変数を設定
BUDGET_NAME = os.environ['BUDGET_NAME']
PROBLEM = os.environ['PROBLEM']
DATA_DIR = os.environ['DATA_DIR']
TRAIN_DIR = os.environ['TRAIN_DIR']
HPARAMS = os.environ['HPARAMS']
MODEL = os.environ['MODEL']

# checkpointファイルのパス
src_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
dist_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
# checkpointファイルをGCSからダウンロード
download_blob(BUDGET_NAME, src_file_name, dist_file_name)

# checkpointファイルから最新のcheckpointシーケンス(prefix)を取得する
import re
with open(dist_file_name) as f:
    l = f.readlines(1)
    ckpt_name = re.findall('model_checkpoint_path: "(.*?)"', l[0])[0]
    ckpt_path = os.path.join(TRAIN_DIR, ckpt_name)
# 最新のcheckpointに紐づくファイルリストをGCSから取得する
ckpt_file_list = list_match_file_with_prefix(BUDGET_NAME, TRAIN_DIR, ckpt_path)

# checkpoint.variableを一式ダウンロード
for ckpt_file in ckpt_file_list:
    download_blob(BUDGET_NAME, ckpt_file, ckpt_file)

# vocabファイルパスをGCSから取得する
vocab_file = list_match_file_with_prefix(BUDGET_NAME, DATA_DIR, os.path.join(DATA_DIR, 'vocab'))[0]
# vocabファイルをGCSからダウンロード
download_blob(BUDGET_NAME, vocab_file, vocab_file)

学習済みtransformerモデルをロード

GCSからダウンロードしたtransformer学習結果をもとに、transformerモデルをロードします。

# 初期化
tfe = tf.contrib.eager
tfe.enable_eager_execution()
Modes = tf.estimator.ModeKeys


import pickle

import numpy as np

from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry

# 前処理&学習で定義したPROBLEと同一のClass名にすること
@registry.register_problem
class Translate_JPEN(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**13

enfr_problem = problems.problem(PROBLEM)
# Get the encoders from the problem
encoders = enfr_problem.feature_encoders(DATA_DIR)

from functools import wraps
import time

def stop_watch(func) :
    @wraps(func)
    def wrapper(*args, **kargs) :
        start = time.time()
        print(f'{func.__name__} started ...')
        result = func(*args,**kargs)
        elapsed_time =  time.time() - start
        print(f'elapsed_time:{elapsed_time}')
        print(f'{func.__name__} completed')
        return result
    return wrapper

@stop_watch
def translate(inputs):
    encoded_inputs = encode(inputs)
    with tfe.restore_variables_on_create(ckpt_path):
        model_output = translate_model.infer(features=encoded_inputs)["outputs"]
    return decode(model_output)

def encode(input_str, output_str=None):
    """Input str to features dict, ready for inference"""
    inputs = encoders["inputs"].encode(input_str) + [1]
    batch_inputs = tf.reshape(inputs, [1, -1, 1])
    return {"inputs": batch_inputs}

def decode(integers):
    """List of ints to str"""
    integers = list(np.squeeze(integers))
    if 1 in integers:
        integers = integers[:integers.index(1)]
    return encoders["inputs"].decode(np.squeeze(integers))

hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)
translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)

推論

ロードしたtransformerモデルで推論します。
ローカルで実行すると、1つsentenceで30秒程度時間がかかります。

inputs = "My cat is so cute."
outputs = translate(inputs)
print(outputs)
結果
> 私の猫はとてもかわいいです。

参考

Welcome to the Tensor2Tensor Colab

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?