本稿では、チャットボットを訓練するための大量データ収集方法として、Twitterから会話データを集める方法を説明します。
1. はじめに
以前、KerasベースのAttention付きSeq2Seqモデルによる[チャットボットを作成しました] (https://qiita.com/gacky01/items/5cc14af9f27ce38b94a8 "Kerasで実装するSeq2Seq −その4 Attention")が、応答文生成の精度は今一つでした。
その時の日本語コーパスの規模は、10万会話対程度(1会話対は発話と応答の対を表すものとします)でしたが、これを大きくすることで、応答文生成の精度向上を図ります。
この時、問題になるのが会話データの入手先ですが、今回はTwitterのAPIを利用して、Twitterからデータを集めることにしました。
2. 本稿のゴール
TwitterAPI を用いて、100万会話対以上の会話データ取得を目指します。
なお、本稿の前提となるソフトウェア環境は、以下の通りです。
【開発環境】
- Ubuntu 16.04.5 LTS
- Python 3.6.4
- Anaconda 5.1.0
- TensorFlow-gpu 1.3.0
- Keras 2.1.4
- Jupyter 4.4.0
【実行環境】
- Ubuntu 16.04.5 LTS
- Python 3.6.7
- TensorFlow 1.4.1
- Keras 2.0.6
3. TwitterAPIの利用方法
3-1. アプリケーションの登録
TwitterAPIを利用するには、Twitterの開発者用サイトで、自分が作ろうとするアプリケーションを登録して、そのアプリケーションが認証に使用する認証キーを発行してもらう必要があります。その方法ですが、2018年の7月から変更されていますので、注意が必要です。こちらの記事に新しい登録方法が載っていますので、ご確認ください。
認証キーは、都合4種類あります(Consumer Key、Consumer Secret、Access Token、Access Token Secret)。
3-2. ライブラリのダウンロード
こちらの記事を参考にしました。以下のコマンドを投入します。
$ pip install requests requests_oauthlib
4. 会話データ収集処理
4-1. 概要
全体的な方式はこちらのサイトを参考にしました。基本的な流れは、①認証キーをもとにセッション確立②ツイート取得回数制限情報取得③応答ツイート取得(100メッセージ)④発話ツイート取得⑤不要文字削除⑥ファイル書き込み、となります。②から⑥までを繰り返して、発話文と応答文の対を収集します。
4-2. セッション確立
3-1節で取得した認証キーを引数にAPIを発行し、リターンとしてセッションオブジェクトを得ます。今回の実装では、認証キーはハードコーディングします。
4-3. ツイート取得回数制限情報取得
Twitterのツイートダウンロードには、時間当たりの回数制限があります。あと何回getメソッドの発行が可能か、発行が規定回数に達したら制限解除まで何秒待てばよいか、という情報もAPIで取得可能です。
そこで、上述のサイトを参考に実装しました。
今回の実装では、3種類のgetメソードを使用しますので、それぞれに対する制限情報を取得します。取得先URL、受信したjsonから情報を取得するためのキー、および15分あたりのgetメソード発行制限回数は以下の通りです。
種類 | 取得先URL | キー1 | キー2 | キー3 | 15分あたり制限回数 |
---|---|---|---|---|---|
応答ツイート取得 | https://api.twitter.com/1.1/search/tweets.json | 'resources' | 'search' | '/search/tweets' | 180 |
発話ツイート取得 | https://api.twitter.com/1.1/statuses/lookup.json | 'resources' | 'statuses' | '/statuses/lookup' | 900 |
制限情報取得 | https://api.twitter.com/1.1/application/rate_limit_status.json | 'resources' | 'application' | '/statuses/lookup' | 180 |
4-4. 応答ツイート取得
各ツイートはツイートされた時点では、まだ応答されていないので、仮にその後応答があったとしても、ツイート情報から応答ツイートを取得することはできません。しかし、そのツイートが別のツイートの応答かどうかは、調べる方法があります。
その手順ですが、まずツイートを取得し、その中から他のツイートに対する応答であるものを抽出し、発話ツイートを特定して取得するという流れで、発話と応答の対を収集します。
ツイートの取得には、セッションオブジェクトのgetメソッドを使用します。1回のメソッドで取得できるツイート数の上限は100なので、パラメータにその値を設定してメソッドを発行します。
なお、ループを回す関係上、おなじツイートを2度受信、処理してしまう可能性があります。これを防ぐため、処理対象のツイートをツイート時刻を使って絞り込みます。時刻処理については、こちらの記事を参考にしました。
ツイート時刻は、次のようにして取得します。getメソッドのリターンとして得られるツイート情報は、json形式になっていますが、いくつかあるキーのうち、「id」キーの値から、右側22bitを削除した残りが時刻情報になっています。各ツイートからこの値を集め、その最大値を次回処理に引き渡してやります。
次のループ処理のなかで、引き渡された値より時刻情報が大きいツイートのみ処理することで、ツイートの重複処理を防ぎます。
4-5. 発話ツイート取得
こちらの記事を参考にしました。
4-3節で発行したgetメソッドで得たツイートのうち、「in_reply_to_status_id_str」キーの値がブランクでないものは、そのidを付与されたツイートに対する応答です。今度はそのidを引数にgetメソッドを呼ぶことで、発話ツイートを取得できます。
4-6. 不要文字削除
こちらの記事を参考に、不要文字(URLなど)を削除します。
また、絵文字も「駆逐」します。絵文字の削除には苦労したので、ググって見つけた2つの方法を、両方とも取り入れました。こちらとこちらです。
4-7. ファイル書き込み
収集した発話ツイートの先頭に識別子「REQ:」を、応答ツイートの先頭に識別子「RES:」を付与して、ファイルに書き込みます。付与した識別子は、後で訓練データを作成する際に使用します。
ファイルは「tweet」フォルダに「tweetYYYY-MM-DD.txt」という名称で格納されます。ソースコードを実行する際には、実行フォルダの直下にあらかじめ、「tweet」フォルダを作成しておいてください。
4-8. ソースコード
以下にソースコードを提示します。認証キーはつぶしてありますので、使用される際には、各自で入手した認証キーで埋めてください。
# coding: utf-8
# tweet取得処理
# In[1]:
# -*- coding: utf-8 -*-
#*******************************************************************************
# *
# tweet_idから時刻情報を取得する *
# *
#*******************************************************************************
def tweet_id2time(tweet_id) :
id_bin = bin(tweet_id>>22)
tweet_time=int(id_bin,2)
tweet_time += 1288834974657
return tweet_time
#*******************************************************************************
# *
# 発話tweet本文取得 *
# *
#*******************************************************************************
def getTweet(res,start_time,reset):
res_text = json.loads(res.text)
url1 = 'https://api.twitter.com/1.1/statuses/user_timeline.json' #今回こちらは使わない
url2 = 'https://api.twitter.com/1.1/statuses/lookup.json'
cnt_req = 0
max_tweet=start_time
total_text = [] # tweet本文(発話/応答)のリスト
tweet_list = [] # n_reply_to_status_idと応答tweetの対のリスト
#--------------------------------------------------------------------------*
# *
# 応答tweet抽出取得 *
# *
#--------------------------------------------------------------------------*
for tweet in res_text['statuses']:
status_id = tweet['in_reply_to_status_id_str']
tweet_id=tweet['id'] # 応答tweetのid
if status_id != None : # 当該tweetが応答かどうかの判断
tweet_time = tweet_id2time(tweet_id)
if tweet_time <= start_time : # 前回処理より新しいtweetのみ処理する
continue
if max_tweet < tweet_time :
max_tweet = tweet_time
res_sentence = tweet['text']
#RTを対象外にする
if res_sentence[0:3] == "RT " :
continue
res_sentence = screening(res_sentence)
if res_sentence == '' :
continue
tweet_list.append([status_id,res_sentence])
if len(tweet_list) == 0 :
return max_tweet,cnt_req ,total_text
#複数status_idを連結する
id_list = tweet_list[0][0]
for i in range(1,len(tweet_list)) :
id_list += ','
id_list += tweet_list[i][0]
#--------------------------------------------------------------------------*
# *
# 発話tweet抽出取得 *
# *
#--------------------------------------------------------------------------*
#複数status_id指定で発話tweet取得
unavailableCnt = 0
while True :
try :
req = session.get(url2, params = {'id':id_list ,'count':len(tweet_list)})
except SocketError as e:
print('ソケットエラー errno=',e.errno)
if unavailableCnt > 10:
raise
waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
unavailableCnt += 1
continue
if req.status_code == 503:
# 503 : Service Unavailable
if unavailableCnt > 10:
raise Exception('Twitter API error %d' % res.status_code)
unavailableCnt += 1
print ('Service Unavailable 503')
waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
continue
unavailableCnt = 0
if req.status_code == 200 :
req_text = json.loads(req.text)
break
else :
raise Exception('Twitter API error %d' % res.status_code)
# 発話tweet本文スクリーニング
for i in range(0,len(tweet_list)) :
for j in range(0,len(req_text)) :
if req_text[j]['id_str'] == tweet_list[i][0] :
req_sentence = req_text[j]['text']
if len(req_text) <= 0 :
print(req_text)
continue
req_sentence = req_text[j]['text']
#RTを対象外にする
if req_sentence[0:3] == "RT " :
continue
req_sentence = screening(req_sentence)
#スクリーニングの結果、ブランクだったら対象外
if req_sentence == '' :
continue
# 発話tweetと応答tweetを対で書き込み
if req_sentence != tweet_list[i][1] :
total_text.append("REQ:"+req_sentence)
total_text.append('RES:'+tweet_list[i][1])
cnt_req += 1
max_tweet = max(max_tweet,start_time)
return max_tweet,cnt_req ,total_text
# tweet本文スクリーニング
# In[2]:
#*******************************************************************************
# *
# tweet本文スクリーニング *
# *
#*******************************************************************************
def screening(text) :
s = text
#RTを外す
if s[0:3] == "RT " :
s = s.replace(s[0:3],"")
#@screen_nameを外す
while s.find("@") != -1 :
index_at = s.find("@")
if s.find(" ") != -1 :
index_sp = s.find(" ",index_at)
if index_sp != -1 :
s = s.replace(s[index_at:index_sp+1],"")
else :
s = s.replace(s[index_at:],"")
else :
s = s.replace(s[index_at:],"")
#改行を外す
while s.find("\n") != -1 :
index_ret = s.find("\n")
s = s.replace(s[index_ret],"")
#URLを外す
s = re.sub(r'https?://[\w/:%#\$&\?\(\)~\.=\+\-…]+', "", s)
#絵文字を「。」に置き換え その1
non_bmp_map = dict.fromkeys(range(0x10000, sys.maxunicode + 1), '。')
s = s.translate(non_bmp_map)
#絵文字を「。」に置き換え その2
s=''.join(c if c not in emoji.UNICODE_EMOJI else '。' for c in s )
#置き換えた「。」が連続していたら1つにまとめる
while s.find('。。') != -1 :
index_period = s.find('。。')
s = s.replace(s[index_period:index_period+2],'。')
#ハッシュタグを外す
while s.find('#') != -1 :
index_hash = s.find('#')
s = s[0:index_hash]
return s
# 回数制限を問合せ、アクセス可能になるまで wait する
# In[3]:
#*******************************************************************************
# *
# 回数制限を問合せ、アクセス可能になるまで wait する *
# *
#*******************************************************************************
def checkLimit(session):
unavailableCnt = 0
url = "https://api.twitter.com/1.1/application/rate_limit_status.json"
while True :
try :
res = session.get(url)
except SocketError as e:
print('erron=',e.errno)
print('ソケットエラー')
if unavailableCnt > 10:
raise
waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
unavailableCnt += 1
continue
if res.status_code == 503:
# 503 : Service Unavailable
if unavailableCnt > 10:
raise Exception('Twitter API error %d' % res.status_code)
unavailableCnt += 1
print ('Service Unavailable 503')
waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
continue
unavailableCnt = 0
if res.status_code != 200:
raise Exception('Twitter API error %d' % res.status_code)
remaining_search,remaining_user, remaining_limit ,reset = getLimitContext(json.loads(res.text))
if remaining_search <= 1 or remaining_user <=1 or remaining_limit <= 1:
waitUntilReset(reset+30)
else :
break
sec = reset - time.mktime(datetime.datetime.now().timetuple())
print(remaining_search,remaining_user, remaining_limit ,sec)
return reset
#*******************************************************************************
# *
# sleep処理 resetで指定した時間スリープする *
# *
#*******************************************************************************
def waitUntilReset(reset):
seconds = reset - time.mktime(datetime.datetime.now().timetuple())
seconds = max(seconds, 0)
print ('\n =====================')
print (' == waiting %d sec ==' % seconds)
print (' =====================')
sys.stdout.flush()
time.sleep(seconds + 10) # 念のため + 10 秒
#*******************************************************************************
# *
# 回数制限情報取得 *
# *
#*******************************************************************************
def getLimitContext(res_text):
# searchの制限情報
remaining_search = res_text['resources']['search']['/search/tweets']['remaining']
reset1 = res_text['resources']['search']['/search/tweets']['reset']
# lookupの制限情報
remaining_user = res_text['resources']['statuses']['/statuses/lookup']['remaining']
reset2 = res_text['resources']['statuses']['/statuses/lookup']['reset']
# 制限情報取得の制限情報
remaining_limit = res_text['resources']['application']['/application/rate_limit_status']['remaining']
reset3 = res_text['resources']['application']['/application/rate_limit_status']['reset']
return int(remaining_search),int(remaining_user),int(remaining_limit) ,max(int(reset1),int(reset2),int(reset3))
# メイン処理
# In[8]:
#*******************************************************************************
# *
# メイン処理 *
# *
#*******************************************************************************
if __name__ == '__main__':
from requests_oauthlib import OAuth1Session
import json
import datetime, time, sys
import re
import datetime
import emoji
import sys
from socket import error as SocketError
import errno
CK = '*************************' # Consumer Key
CS = '**************************************************' # Consumer Secret
AT = '**************************************************' # Access Token
AS = '*********************************************' # Accesss Token Secert
args = sys.argv
#args[1] = '私' # jupyter上で実行するとき用
session = OAuth1Session(CK, CS, AT, AS)
#--------------------------------------------------------------------------*
# *
# tweet取得処理 *
# *
#--------------------------------------------------------------------------*
total= -1
total_count = 0
cnt = 0
unavailableCnt = 0
url = 'https://api.twitter.com/1.1/search/tweets.json'
start_time = 1288834974657
while True:
#----------------
# 回数制限を確認
#----------------
#
reset = checkLimit(session)
get_time = time.mktime(datetime.datetime.now().timetuple()) #getの時刻取得
try :
res = session.get(url, params = {'q':args[1], 'count':100})
except SocketError as e:
print('ソケットエラー errno=',e.errno)
if unavailableCnt > 10:
raise
waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
unavailableCnt += 1
continue
if res.status_code == 503:
# 503 : Service Unavailable
if unavailableCnt > 10:
raise Exception('Twitter API error %d' % res.status_code)
unavailableCnt += 1
print ('Service Unavailable 503')
waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
continue
unavailableCnt = 0
if res.status_code != 200:
raise Exception('Twitter API error %d' % res.status_code)
#----------------
# 取得したtweetに対する発話取得とファイル書き込み
#----------------
start_time ,count ,total_text = getTweet(res,start_time,reset)
date = datetime.date.today()
fname = 'tweet/tweet'+str(date)+'.txt'
f=open(fname,'a')
for i in range(0,len(total_text)):
f.write(str(total_text[i])+"\n")
f.close()
total_count += count
print('total_count=',total_count,'start_time=',start_time)
current_time = time.mktime(datetime.datetime.now().timetuple())
# 処理時間が2秒未満なら2秒wait
if current_time - get_time < 2 :
waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 2)
#デバッグ用
if total > 0 :
cnt += 100
if total > 0 and cnt >= total:
break
5. 会話データ収集処理実行
ソースコードのファイルを置いたフォルダ上で、以下のコマンドを投入します。以下の例では、ファイル名は「conversation_py3.py」です。どこで実行しても問題ありませんが、実行前にあらかじめ、収集データ格納用の「tweet」フォルダを直下に作成しておいてください。
引数として、応答ツイート収集のための検索ワードを指定します。以下の例では、「です」を指定してあります。
$ python conversation_py3.py です
実際の実行時には、検索ワードとして、「です」のほかに、「私」「ある」を指定しました。
1日じゅう起動しっぱなしにしておきましたが、面白いようにツイートを取得できます。1日あたり、20万対以上の会話対を集めることができ、5日程度で、目標である100万会話対を取得することができました。
昼間のうちは問題なく動作していましたが、特に夜半を過ぎて午前2時以降になると、500番エラーが発生して、コマンドが停止していることが有りました。
ソケットエラーは比較的よく発生するので、リトライするようにしてあります。
6. おわりに
以上、チャットボット訓練用に、ツイートを大量に収集する方法について記述しました。今回の手法で収集したデータを用いて、以前作成した[KerasベースのSeq2Seqモデル] (https://qiita.com/gacky01/items/5cc14af9f27ce38b94a8 "Kerasで実装するSeq2Seq −その4 Attention")を訓練しましたが、その結果は、以下の投稿にまとめてあります。