41
37

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

LIFULLその2Advent Calendar 2018

Day 24

TensorFlowで対話AIを作ってみる

Last updated at Posted at 2018-12-26

機械学習エンジニアでもなんでもないのですが、趣味で TensorFlowで会話AIを作ってみた をはじめとした参考資料を元に、Seq2Seq Model(Sequence-to-Sequence Models)を利用した会話(対話)AIを作成したので、備忘録も兼ねてその作成手順をまとめておきます。

モチベーションとしては、翻訳モデルを利用して会話を構築するところで、例えば**「英語 →
日本語」の変換を「発言 → 返信」に置き換えて考えれば会話として成立するのでは?**という考えです。

会話AIやその種類についての説明は上記スライドにまとまっています。
また、TensorFlowのチュートリアルを行なったわけでもないので深い理解をしておらず、なんとなく動くがこの記事のゴールとなっています。ご了承ください。

TensorFlowとは

公式サイト : https://js.tensorflow.org/

一言で言えばGoogleが提供している数値計算ライブラリです。
テンソルと呼ばれる多次元配列が、データフローグラフ(データーの流れを定義したグラフのアーキテクチャ)上で計算されます。

環境

  • python : 2.7.15
  • tensorflow : 0.12.0
  • mecab-python : 0.996

今回はじめてtensorflowを利用したのですが、バージョンが1.x系だとmodelが動かなかったのでtensorflowのバージョンを下げて利用しました。
(なんとかなるだろうと思い色々試したのですが、時間の関係もあり挫折しました。。

手順

TensorFlowのインストール

インストールガイドを参考にTensorFlowをインストールします。

学習データの準備1

今回、データセットとして名大会話コーパスを利用しました。
中身を見てわかる通り、文章ベースの対話データセットとしては適切ではないものの3万以上のデータあるという点と、これを利用した前例があるという観点から利用しました。
実際に会話AIとして運用するような場合はデータを整形して利用したり、他のデータセットを利用するのも良いかと思います。

学習データの準備2

準備したデータセットをモデルが学習できるような形式に合わせる必要があります。

手順としては、まずこちらのツールを利用し、READMEの通りsequence.txtという生データを生成させます。
次に、sequence.txtに形態素解析を行い、以下のような行ごとに一対一対応するような2ファイルを生成させます。
MeCabの使い方はこちらを参考

input.txt
こんにちは
今日 は いい 天気 だね
・
・
output.txt
kamihorkさん こんにちは
そうだね とても 晴れて いて 良い 天気!
・
・

各データの準備

用意したinput.txtとoutput.txtを元に、次のファイル群を用意します。
今回は、訓練データと評価データの割合は7:3ぐらいでおきました。

このように、データ全体を学習データとテストデータに分割してモデルの性能を確かめる手法を、ホールドアウト方式というようです。データ量がある程度ない場合は他の評価方法を用いると良いみたいです。

ホールドアウト法は、有限のデータを学習用とテスト用に分割するので、 学習用を多くすればテスト用が減り、汎化性能の評価精度が落ちる。 逆にテスト用を多くすれば学習用が減少するので、学習そのものの精度が悪くなる。 したがって、手元のデータが大量にある場合を除いて、良い性能評価を与えないという欠点がある。

ファイル名 概要
seq2seq_model.py モデル
data_utils.py データ整形
translate.py 学習の実行
datas/train_data_in.txt 訓練データ(発言テキスト)
datas/train_data_out.txt 訓練データ(返信テキスト)
datas/test_data_in.txt 評価データ(発言テキスト)
datas/test_data_out.txt 評価データ(返信テキスト)
datas/vocab_in.txt 訓練語録データ(発言テキスト)
datas/vocab_out.txt 訓練語録データ(返信テキスト)
datas/test_data_ids_in.txt 評価語録のIDデータ(発言テキスト)
datas/test_data_ids_ut.txt 評価語録のIDデータ(返信テキスト)
seq2seq_model.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0

"""Sequence-to-sequence model with an attention mechanism."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

import numpy as np
from six.moves import xrange  
import tensorflow as tf

from tensorflow.models.rnn import rnn_cell
from tensorflow.models.rnn import seq2seq

import data_utils


class Seq2SeqModel(object):   

  def __init__(self, source_vocab_size, target_vocab_size, buckets, size,
               num_layers, max_gradient_norm, batch_size, learning_rate,
               learning_rate_decay_factor, use_lstm=False,
               num_samples=512, forward_only=False):
   
    self.source_vocab_size = source_vocab_size    
    self.target_vocab_size = target_vocab_size
    self.buckets = buckets
    self.batch_size = batch_size
    self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)

    
    output_projection = None
    softmax_loss_function = None
    
    if num_samples > 0 and num_samples < self.target_vocab_size:     
      with tf.device("/cpu:0"):
        w = tf.get_variable("proj_w", [size, self.target_vocab_size])            
        w_t = tf.transpose(w)
        b = tf.get_variable("proj_b", [self.target_vocab_size])
      output_projection = (w, b)                                        


      def sampled_loss(inputs, labels):
        with tf.device("/cpu:0"):
          labels = tf.reshape(labels, [-1, 1])
          return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
                                            self.target_vocab_size)
      softmax_loss_function = sampled_loss

    
    single_cell = rnn_cell.GRUCell(size)
    if use_lstm:                                    
      single_cell = rnn_cell.BasicLSTMCell(size)    
    cell = single_cell
    if num_layers > 1:
      cell = rnn_cell.MultiRNNCell([single_cell] * num_layers)


    def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
      return seq2seq.embedding_attention_seq2seq(
          encoder_inputs, decoder_inputs, cell, source_vocab_size,
          target_vocab_size, output_projection=output_projection,
          feed_previous=do_decode)

    self.encoder_inputs = []
    self.decoder_inputs = []
    self.target_weights = []
    for i in xrange(buckets[-1][0]):  
      self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                name="encoder{0}".format(i)))
    for i in xrange(buckets[-1][1] + 1):
      self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                name="decoder{0}".format(i)))
      self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
                                                name="weight{0}".format(i)))


    targets = [self.decoder_inputs[i + 1]
               for i in xrange(len(self.decoder_inputs) - 1)]


    if forward_only:
      self.outputs, self.losses = seq2seq.model_with_buckets(
          self.encoder_inputs, self.decoder_inputs, targets,
          self.target_weights, buckets, self.target_vocab_size,
          lambda x, y: seq2seq_f(x, y, True),
          softmax_loss_function=softmax_loss_function)

      if output_projection is not None:
        for b in xrange(len(buckets)):
          self.outputs[b] = [tf.matmul(output, output_projection[0]) +
                             output_projection[1]
                             for output in self.outputs[b]]
    else:
      self.outputs, self.losses = seq2seq.model_with_buckets(
          self.encoder_inputs, self.decoder_inputs, targets,
          self.target_weights, buckets, self.target_vocab_size,
          lambda x, y: seq2seq_f(x, y, False),
          softmax_loss_function=softmax_loss_function)


    params = tf.trainable_variables()
    if not forward_only:
      self.gradient_norms = []
      self.updates = []
      opt = tf.train.GradientDescentOptimizer(self.learning_rate)
      for b in xrange(len(buckets)):
        gradients = tf.gradients(self.losses[b], params)
        clipped_gradients, norm = tf.clip_by_global_norm(gradients,
                                                         max_gradient_norm)
        self.gradient_norms.append(norm)
        self.updates.append(opt.apply_gradients(
            zip(clipped_gradients, params), global_step=self.global_step))

    self.saver = tf.train.Saver(tf.all_variables())

  def step(self, session, encoder_inputs, decoder_inputs, target_weights,
           bucket_id, forward_only):

    encoder_size, decoder_size = self.buckets[bucket_id]
    if len(encoder_inputs) != encoder_size:
      raise ValueError("Encoder length must be equal to the one in bucket,"
                       " %d != %d." % (len(encoder_inputs), encoder_size))
    if len(decoder_inputs) != decoder_size:
      raise ValueError("Decoder length must be equal to the one in bucket,"
                       " %d != %d." % (len(decoder_inputs), decoder_size))
    if len(target_weights) != decoder_size:
      raise ValueError("Weights length must be equal to the one in bucket,"
                       " %d != %d." % (len(target_weights), decoder_size))

    input_feed = {}
    for l in xrange(encoder_size):
      input_feed[self.encoder_inputs[l].name] = encoder_inputs[l]
    for l in xrange(decoder_size):
      input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
      input_feed[self.target_weights[l].name] = target_weights[l]

    last_target = self.decoder_inputs[decoder_size].name
    input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)   

  
    if not forward_only:
      output_feed = [self.updates[bucket_id], 
                     self.gradient_norms[bucket_id], 
                     self.losses[bucket_id]]  
    else:
      output_feed = [self.losses[bucket_id]]  
      for l in xrange(decoder_size):  
        output_feed.append(self.outputs[bucket_id][l])

    outputs = session.run(output_feed, input_feed)
    if not forward_only:
      return outputs[1], outputs[2], None  
    else:
      return None, outputs[0], outputs[1:]  



  def get_batch(self, data, bucket_id):
    
    encoder_size, decoder_size = self.buckets[bucket_id]   
    encoder_inputs, decoder_inputs = [], []
    
    for _ in xrange(self.batch_size):                                     
      encoder_input, decoder_input = random.choice(data[bucket_id])        

      encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input)) 
      encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))       
                                                                      
      decoder_pad_size = decoder_size - len(decoder_input) - 1               
      decoder_inputs.append([data_utils.GO_ID] + decoder_input +             
                            [data_utils.PAD_ID] * decoder_pad_size)          
                                                                       

    batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []


    for length_idx in xrange(encoder_size):                                  
      batch_encoder_inputs.append(
          np.array([encoder_inputs[batch_idx][length_idx]
                    for batch_idx in xrange(self.batch_size)], dtype=np.int32))


    for length_idx in xrange(decoder_size):
      batch_decoder_inputs.append(
          np.array([decoder_inputs[batch_idx][length_idx]
                    for batch_idx in xrange(self.batch_size)], dtype=np.int32))


      batch_weight = np.ones(self.batch_size, dtype=np.float32)           
      for batch_idx in xrange(self.batch_size):

        if length_idx < decoder_size - 1:
          target = decoder_inputs[batch_idx][length_idx + 1]
        if length_idx == decoder_size - 1 or target == data_utils.PAD_ID:   
          batch_weight[batch_idx] = 0.0
      batch_weights.append(batch_weight)
    return batch_encoder_inputs, batch_decoder_inputs, batch_weights          
data_utils.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import re
import tarfile

from tensorflow.python.platform import gfile
from six.moves import urllib

_PAD = "_PAD"
_GO = "_GO"
_EOS = "_EOS"
_UNK = "_UNK"
_START_VOCAB = [_PAD, _GO, _EOS, _UNK]

PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3

_WORD_SPLIT = re.compile("([.,!?\"':;)(])")
_DIGIT_RE = re.compile(r"\d")


def basic_tokenizer(sentence):
  words = []
  for space_separated_fragment in sentence.strip().split():
    words.extend(re.split(_WORD_SPLIT, space_separated_fragment))
  
  return [w for w in words if w]


def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
                      tokenizer=None, normalize_digits=True):
  
  if os.path.exists(vocabulary_path):       
    print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
    vocab = {}
    
    f = open(data_path,"r")
    counter = 0
    for line in f:
      counter += 1
      if counter % 100 == 0:
        print("  processing line %d" % counter)
      tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
      for w in tokens:
        word = re.sub(_DIGIT_RE, "0", w) if normalize_digits else w   
        if word in vocab:       
          vocab[word] += 1
        else:
          vocab[word] = 1
    vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
    if len(vocab_list) > max_vocabulary_size:
      vocab_list = vocab_list[:max_vocabulary_size]       
      
    vocab_file = open(vocabulary_path,"w")
    for w in vocab_list:
      vocab_file.write(w + "\n")
    vocab_file.close()  
    f.close()

def initialize_vocabulary(vocabulary_path):
  
  if os.path.exists(vocabulary_path):       
    rev_vocab = []

    f = open(vocabulary_path,"r")
    rev_vocab.extend(f.readlines())
    rev_vocab = [line.strip() for line in rev_vocab]
    vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
    f.close()
    return vocab, rev_vocab

  else:
    raise ValueError("Vocabulary file %s not found.", vocabulary_path)


def sentence_to_token_ids(sentence, vocabulary,
                          tokenizer=None, normalize_digits=True):
  
  if tokenizer:
    words = tokenizer(sentence)
  else:
    words = basic_tokenizer(sentence)
  if not normalize_digits:
    return [vocabulary.get(w, UNK_ID) for w in words]
  return [vocabulary.get(re.sub(_DIGIT_RE, "0", w), UNK_ID) for w in words]


def data_to_token_ids(data_path, target_path, vocabulary_path,
                      tokenizer=None, normalize_digits=True):
  
  print("Tokenizing data in %s" % data_path)
  vocab, _ = initialize_vocabulary(vocabulary_path)                 
  data_file = open(data_path,"r")
  tokens_file = open(target_path,"w")
  counter = 0
  for line in data_file:
    counter += 1
    if counter % 100 == 0:
      print("  tokenizing line %d" % counter)
    token_ids = sentence_to_token_ids(line, vocab, tokenizer,         
                                            normalize_digits)
    tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")   
  data_file.close()
  tokens_file.close()      

def prepare_wmt_data(data_dir, in_vocabulary_size, out_vocabulary_size):
                    
  train_path = os.path.join(data_dir, "train_data")               
  dev_path = os.path.join(data_dir, "test_data")                     


  out_vocab_path = os.path.join(data_dir, "vocab_out.txt" )   
  in_vocab_path = os.path.join(data_dir, "vocab_in.txt" )   
  create_vocabulary(out_vocab_path, train_path + "_out.txt", out_vocabulary_size)    
  create_vocabulary(in_vocab_path, train_path + "_in.txt", in_vocabulary_size)    

  out_train_ids_path = train_path + ("_ids_out.txt" )         
  in_train_ids_path = train_path + ("_ids_in.txt" )        
  data_to_token_ids(train_path + "_out.txt", out_train_ids_path, out_vocab_path)    
  data_to_token_ids(train_path + "_in.txt", in_train_ids_path, in_vocab_path)    


  out_dev_ids_path = dev_path + ("_ids_out.txt" )             
  in_dev_ids_path = dev_path + ("_ids_in.txt" )             
  data_to_token_ids(dev_path + "_out.txt", out_dev_ids_path, out_vocab_path)         
  data_to_token_ids(dev_path + "_in.txt", in_dev_ids_path, in_vocab_path)         

  return (in_train_ids_path, out_train_ids_path,           
          in_dev_ids_path, out_dev_ids_path,
          in_vocab_path, out_vocab_path)

なお、translate.pyについては**batch_size, num_layers, size, vocab_sizeを調整することでモデルの性能(学習効率)を調整する**ことができます。

translate.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys
import time

import tensorflow.python.platform

import numpy as np
from six.moves import xrange  
import tensorflow as tf

import data_utils
from tensorflow.models.rnn.translate import seq2seq_model
from tensorflow.python.platform import gfile


tf.app.flags.DEFINE_float("learning_rate", 0.5, "Learning rate.")
tf.app.flags.DEFINE_float("learning_rate_decay_factor", 0.99,
                          "Learning rate decays by this much.")
tf.app.flags.DEFINE_float("max_gradient_norm", 5.0,
                          "Clip gradients to this norm.")
tf.app.flags.DEFINE_integer("batch_size", 64,
                            "Batch size to use during training.")
tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.")
tf.app.flags.DEFINE_integer("num_layers", 1024, "Number of layers in the model.")
tf.app.flags.DEFINE_integer("in_vocab_size", 50, "input vocabulary size.")
tf.app.flags.DEFINE_integer("out_vocab_size", 50, "output vocabulary size.")
tf.app.flags.DEFINE_string("data_dir", "datas", "Data directory")       
tf.app.flags.DEFINE_string("train_dir", "datas", "Training directory.")
tf.app.flags.DEFINE_integer("max_train_data_size", 0,
                            "Limit on the size of training data (0: no limit).")
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100,
                            "How many training steps to do per checkpoint.")
tf.app.flags.DEFINE_boolean("decode", False,
                            "Set to True for interactive decoding.")
tf.app.flags.DEFINE_boolean("self_test", False,
                            "Run a self-test if this is set to True.")

FLAGS = tf.app.flags.FLAGS

_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]


def read_data(source_path, target_path, max_size=None):
  data_set = [[] for _ in _buckets]
  source_file = open(source_path,"r")
  target_file = open(target_path,"r")

  source, target = source_file.readline(), target_file.readline()     
  counter = 0
  while source and target and (not max_size or counter < max_size):
    counter += 1
    if counter % 50 == 0:                         
      print("  reading data line %d" % counter)
      sys.stdout.flush()

    source_ids = [int(x) for x in source.split()]   
    target_ids = [int(x) for x in target.split()]    
    target_ids.append(data_utils.EOS_ID)             
    for bucket_id, (source_size, target_size) in enumerate(_buckets):        
      if len(source_ids) < source_size and len(target_ids) < target_size:    
        data_set[bucket_id].append([source_ids, target_ids])                 
        break
    source, target = source_file.readline(), target_file.readline()
  return data_set

def create_model(session, forward_only):
  model = seq2seq_model.Seq2SeqModel(
      FLAGS.in_vocab_size, FLAGS.out_vocab_size, _buckets,                           
      FLAGS.size, FLAGS.num_layers, FLAGS.max_gradient_norm, FLAGS.batch_size,      
      FLAGS.learning_rate, FLAGS.learning_rate_decay_factor,                       
      forward_only=forward_only)                                                    

  ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)                             
  if ckpt and gfile.Exists(ckpt.model_checkpoint_path):                              
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)          
    model.saver.restore(session, ckpt.model_checkpoint_path)                        
  else:
    print("Created model with fresh parameters.")
    session.run(tf.initialize_all_variables())                                      
  return model                                                                      


def train():

  print("Preparing data in %s" % FLAGS.data_dir)                                
  in_train, out_train, in_dev, out_dev, _, _ = data_utils.prepare_wmt_data(           
      FLAGS.data_dir, FLAGS.in_vocab_size, FLAGS.out_vocab_size)                     


  with tf.Session() as sess:


    print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))       
    model = create_model(sess, False)                                               

    print ("Reading development and training data (limit: %d)."     
           % FLAGS.max_train_data_size)                                             
    dev_set = read_data(in_dev, out_dev)                                             
    train_set = read_data(in_train, out_train, FLAGS.max_train_data_size)            

    train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]         
    train_total_size = float(sum(train_bucket_sizes))                               


    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size     
                           for i in xrange(len(train_bucket_sizes))]              

    step_time, loss = 0.0, 0.0
    current_step = 0
    previous_losses = []
    while True:

      random_number_01 = np.random.random_sample()                     
      bucket_id = min([i for i in xrange(len(train_buckets_scale))      
                       if train_buckets_scale[i] > random_number_01])

      start_time = time.time()
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(   
          train_set, bucket_id)                                           

      _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, 
                                   target_weights, bucket_id, False)      
      step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
      loss += step_loss / FLAGS.steps_per_checkpoint
      current_step += 1


      if current_step % FLAGS.steps_per_checkpoint == 0:

        perplexity = math.exp(loss) if loss < 300 else float('inf')
        print ("global step %d learning rate %.4f step-time %.2f perplexity "
               "%.2f" % (model.global_step.eval(), model.learning_rate.eval(),
                         step_time, perplexity))


        if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
          sess.run(model.learning_rate_decay_op)
        previous_losses.append(loss)

        checkpoint_path = os.path.join(FLAGS.train_dir, "translate.ckpt")
        model.saver.save(sess, checkpoint_path, global_step=model.global_step)
        step_time, loss = 0.0, 0.0

        for bucket_id in xrange(len(_buckets)):
          encoder_inputs, decoder_inputs, target_weights = model.get_batch(
              dev_set, bucket_id)
          _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                       target_weights, bucket_id, True)
          eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
          print("  eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
        sys.stdout.flush()


def decode():
  with tf.Session() as sess:
    print ("Hello!!")
    model = create_model(sess, True)                         
    model.batch_size = 1  

    in_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab_in.txt")     
    out_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab_out.txt" )

    in_vocab, _ = data_utils.initialize_vocabulary(in_vocab_path)        
    _, rev_out_vocab = data_utils.initialize_vocabulary(out_vocab_path)    


    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()    
    while sentence:

      token_ids = data_utils.sentence_to_token_ids(sentence, in_vocab)   

      bucket_id = min([b for b in xrange(len(_buckets))
                       if _buckets[b][0] > len(token_ids)])               

      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          {bucket_id: [(token_ids, [])]}, bucket_id)

      _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,      
                                       target_weights, bucket_id, True)

      outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]       

      if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]                      

      print(" ".join([rev_out_vocab[output] for output in outputs]))
      print("> ", end="")
      sys.stdout.flush()
      sentence = sys.stdin.readline()                                             


def self_test():

  with tf.Session() as sess:
    print("Self-test for neural translation model.")

    model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2,
                                       5.0, 32, 0.3, 0.99, num_samples=8)
    sess.run(tf.initialize_all_variables())


    data_set = ([([1, 1], [2, 2]), ([3, 3], [4]), ([5], [6])],
                [([1, 1, 1, 1, 1], [2, 2, 2, 2, 2]), ([3, 3, 3], [5, 6])])
    for _ in xrange(5):  
      bucket_id = random.choice([0, 1])
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          data_set, bucket_id)
      model.step(sess, encoder_inputs, decoder_inputs, target_weights,
                 bucket_id, False)


def main(_):
  if FLAGS.self_test:
    self_test()
  elif FLAGS.decode:
    decode()
  else:
    train()

if __name__ == "__main__":
  tf.app.run()

これを、以下のようなディレクトリ構成でファイルを配置します。

myapp
├── datas
│   ├── test_data_ids_in.txt
│   ├── test_data_ids_out.txt
│   │
│   ├── test_data_in.txt
│   ├── test_data_out.txt
│   │
│   ├── train_data_ids_in.txt
│   ├── train_data_ids_out.txt
│   │
│   ├── train_data_in.txt
│   ├── train_data_out.txt
│   │
│   ├── vocab_in.txt
│   └── vocab_out.txt
│
├── data_utils.py
├── seq2seq_model.py
└── translate.py

学習開始

$ python translate.py --data_dir datas --train_dir datas

上記コマンドで学習が開始され、100stepごとにチェックポイントが生成されます。

global step 100 learning rate 0.5000 step-time 0.20 perplexity 113.56
  eval: bucket 0 perplexity 59.33
  eval: bucket 1 perplexity 51.30
  eval: bucket 2 perplexity 116.17
  eval: bucket 3 perplexity 184.61
・・・
・・・

会話をする

キリがいいところで終了し、実際にAIと会話をしてみます。
ちなみに、以下は1000stepほど踏んだ後の実行結果です。

$ python translate.py --decode --data_dir datas --train_dir datas

Hello!!
> こんにちは!!
けど けど けど けど けど そう そう そう そう *

...明らかに学習が足りていないですね。
チェックポイントがあるため、再度translate.pyを実行することで中断した途中から学習を再開してくれるようになっています。便利。

単純に学習時間を増やすのもそうですが、データ量増やしたり、データそのものを見直したり、モデルをチューニングすることでしっかりとした会話ができるようになると思います。

そこで指標となるのが**perplexityという評価指数です。
詳しくはこちらにありますが、一言で言うと
予測される候補指数を表しており、これが小さいほど予測しやすい(=性能が高い)**ということになります。

PDCAを回してどんどん改善させていきたいですね。

参考資料

今回、以下の記事等を参考に作成させていただきました。ありがとうございます。

41
37
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
41
37

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?