17
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【GPT-2】Open AI GPT-2で日本語(多言語)対応ツールを開発した

Last updated at Posted at 2020-01-18

オンラインツール公開

通常のGPT-2は英語しか対応しないが、多言語対応したオンラインツールを開発した。

screencapture-mockers-io-generator-2020-01-18-22_40_58.png

設計

実際には3段構成になっており、

  • 言語判別と英語への翻訳
  • GPT-2本体処理
  • 判別されたオリジナル言語への英語からの復元

という、ニセの多言語対応である。

モデルは最近公開された最高性能の1558Mに対応したので、英語では非常に自然な文章が生成されているようだ。
対して、日本語ではいかにも機械翻訳な文章になる。

実装

WebのマイクロサービスAPIとして提供できる形でのコード全体はGithubであとで公開する予定。

Tensorflowレイヤ

Gitで公開されているコードをWebのリクエストを受けるタイミングで呼び出せように、改変した。
TensorflowでGPT-2を実装する部分。

class GPT2Engine():
    def __init__(self,
        model_name='1558M',
        seed=None,
        nsamples=1,
        batch_size=1,
        length=None,
        temperature=1,
        top_k=0,
        top_p=1,
        models_dir='models',
        return_as_generator=False,
        gpu_list='0,1'
    ):
        self.batch_size = batch_size
        self.nsamples = nsamples

        models_dir = os.path.expanduser(os.path.expandvars(models_dir))
        if self.batch_size  is None:
            self.batch_size  = 1
        assert self.nsamples % self.batch_size  == 0

        self.enc = encoder.get_encoder(model_name, models_dir)
        hparams = model.default_hparams()
        with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

        if length is None:
            length = hparams.n_ctx // 2
        elif length > hparams.n_ctx:
            raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

        from tensorflow.python.client import device_lib

        def get_available_gpus():
            local_device_protos = device_lib.list_local_devices()
            return [x.name for x in local_device_protos if x.device_type == 'GPU']

        available_gpus = get_available_gpus()
        print(available_gpus)
        if len(available_gpus)>0:
            self.config = tf.ConfigProto(
                gpu_options=tf.GPUOptions(
                    visible_device_list=gpu_list, # specify GPU number
                    #per_process_gpu_memory_fraction=0.8,
                    allow_growth=True
                )
            )
        else:
            self.config = tf.ConfigProto()

        self.seed = seed
        self.hparams = hparams
        self.length = length
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.models_dir = models_dir
        self.model_name = model_name

        graph = tf.get_default_graph()
            
        sess = tf.Session(graph=graph,config=self.config)

        self.context = tf.placeholder(tf.int32, [self.batch_size, None])
        np.random.seed(self.seed)
        tf.set_random_seed(self.seed)

        if True:
            self.output = sample.sample_sequence(
                hparams=self.hparams, length=self.length,
                context=self.context,
                batch_size=self.batch_size,
                temperature=self.temperature, top_k=self.top_k, top_p=self.top_p
            )

        else:
            self.output = sample.sample_sequence(
                hparams=self.hparams, length=self.length,
                start_token=self.enc.encoder['<|endoftext|>'],
                batch_size=self.batch_size,
                temperature=self.temperature, top_k=self.top_k, top_p=self.top_p
            )[:, 1:]

        self.saver = tf.compat.v1.train.Saver()        

        ckpt = tf.train.latest_checkpoint(os.path.join(self.models_dir, self.model_name))
        self.restored = self.saver.restore(sess, ckpt)

        print(self.restored)

        self.sess = sess

    def run(self,raw_text,
            length=None,
            temperature=1,
            top_k=0,
            top_p=1
    ):
        
        print("length={}".format(length))

        context_tokens = self.enc.encode(raw_text)

        generated = 0
        for _ in range(self.nsamples // self.batch_size):
            if raw_text is None:
                out = self.sess.run(self.output)
            else:
                out = self.sess.run(self.output, feed_dict={
                    self.context: [context_tokens for _ in range(self.batch_size)]
                })[:, len(context_tokens):]

            for i in range(self.batch_size):
                generated += 1
                text = self.enc.decode(out[i])
                print(text)

                yield text

翻訳用

今回はAWS Translate を利用している。

def get_translate_text(text,source_lang="auto",target_lang="en"):
    AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID', '')
    AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY', '')
    REGION_NAME = os.getenv('REGION_NAME', '')
    translate = boto3.client('translate', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=REGION_NAME)

    response = translate.translate_text(
        Text=text,
        SourceLanguageCode=source_lang,
        TargetLanguageCode=target_lang
    )

    return response['TranslatedText'], response.get('SourceLanguageCode')

文章補正レイヤ

WebエンドポイントとTensorflowクラスの間で、翻訳や文章補正、文字数制限などを行う。

class GPT2WebWrapper():
    def __init__(self):

        def _int(num):
            if num is None:
                return None
            else:
                return int(num)

        self.PASSAGE = os.getenv('PASSAGE', '')
        self.MODE = os.getenv('MODE', 'PREDICT_UNCONDITIONAL')
        self.MODEL_NAME = os.getenv('MODEL_NAME', '1558M')
        self.SEED = _int(os.getenv('SEED', None))
        self.NSAMPLES = _int(os.getenv('NSAMPLES', 1))
        self.BATCH_SIZE = _int(os.getenv('BATCH_SIZE', 1))
        self.LENGTH = _int(os.getenv('LENGTH', 30))
        self.USER_LANGUAGE = _int(os.getenv('USER_LANGUAGE', None))
        self.TEMPERATURE = float(os.getenv('TEMPERATURE', 1.0))
        self.TOP_K = _int(os.getenv('TOP_K', 0))
        self.TOP_P = _int(os.getenv('TOP_P', 1))
        self.MODELS_DIR = os.getenv('MODELS_DIR', 'models')
        self.GPU_MEMORY_FRACTION = float(os.getenv('GPU_MEMORY_FRACTION', 1.0))
        self.GPU_LIST = os.getenv('GPU_LIST', None)
        self.RUN_NAME = os.getenv('RUN_NAME', 'run1')
        self.LENGTH_LIMIT = _int(os.getenv('LENGTH_LIMIT', 400))

        self.engine = GPT2Engine(
            model_name=self.MODEL_NAME,
            seed=self.SEED,
            nsamples=self.NSAMPLES,
            batch_size=self.BATCH_SIZE,
            length=self.LENGTH,
            temperature=self.TEMPERATURE,
            top_k=self.TOP_K,
            top_p=self.TOP_P,
            models_dir=self.MODELS_DIR,
            return_as_generator=True,
            gpu_list=self.GPU_LIST
        )

    def run(self,_in,
            length=None,
            temperature=None,
            top_k=None,
            top_p=None,
            target_lang="en"
    ):

        empty = False

        if _in is None:
            _in = 'Sure.'
            empty = True

        _in_f = ""
        _in_l = _in
        if len(_in) > self.LENGTH_LIMIT:
            _in_f = _in[:-self.LENGTH_LIMIT]
            _in_l = _in[-self.LENGTH_LIMIT:]
            
        translated_in, detected_lang = get_translate_text(_in_l, source_lang='auto', target_lang='en')

        if not empty:
            target_lang = detected_lang

        length = length if length is not None else self.LENGTH
        temperature = temperature if temperature is not None else self.TEMPERATURE
        top_k = top_k if top_k is not None else self.TOP_K
        top_p = top_p if top_p is not None else self.TOP_P

        _out = ""
        iter = self.engine.run(translated_in,
            length=length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        while iter:
            try:
                _out = str(iter.__next__())
            except:
                #traceback.print_exc()
                break

        _out = translated_in+" "+_out
        _out = _out.replace("<|endoftext|>"," ")
        translated_out, input_lang = get_translate_text(_out, source_lang='en', target_lang=target_lang)
        translated_out = translated_out.replace("\n"," ")
        translated_out = _in_f + translated_out
        print(translated_out)
        return _out, translated_in, translated_out, detected_lang

課題

モデルとして日本語ネイティブ対応のほうが、より自然になるだろう。

GPT-2はもともと英語圏のニュースサイトなどを学習データとして利用しているため、そこで語られるような文章を生成するのは得意だ。GPT-2は基本的には、最初に文章を与え、その後に来るのにふさわしい文章を推測する。例えば日本語で「トランプ大統領は」と入れて生成すると、少しはまともな文章になる。

一方で、英語圏の文化から離れた単語を与えた場合、不自然になりやすい。自然な日本語を得るためには、翻訳を使用しないことはもちろんだが、日本語圏の文章を学習させることも非常に重要だと感じる。

ちなみに生成された文章からさらに再帰的に文章を生成可能だが、英語以外だと、翻訳が繰り返し行われることで、どんどん劣化していくことも問題だ。

しかし、日本語ネイティブモデルのニーズってあるのだろうか。

今後の文章AI生成の展望

現時点の大規模教師なし学習での文章生成は自然さという意味では、高くなってきたが、何かに適応できるかというとイマイチであると感じる。
私の個人的意見であるが、以下のよう課題感がある。

  • 常識や事実をいかに学習させられるか
  • 目的を持った文章を生成できるか

このような部分が解決されれば、ビジネス的にも価値がでてくると思う。

おわりに

GPT-2関連ツールは絶賛、開発中なので、興味があれば、こちらの記事も読んでほしい。

17
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?