LoginSignup
39
47

More than 5 years have passed since last update.

TwitterAPIを用いた会話データ収集

Last updated at Posted at 2018-11-12

 本稿では、チャットボットを訓練するための大量データ収集方法として、Twitterから会話データを集める方法を説明します。

1. はじめに

 以前、KerasベースのAttention付きSeq2Seqモデルによるチャットボットを作成しましたが、応答文生成の精度は今一つでした。

 その時の日本語コーパスの規模は、10万会話対程度(1会話対は発話と応答の対を表すものとします)でしたが、これを大きくすることで、応答文生成の精度向上を図ります。

 この時、問題になるのが会話データの入手先ですが、今回はTwitterのAPIを利用して、Twitterからデータを集めることにしました。

2. 本稿のゴール

 TwitterAPI を用いて、100万会話対以上の会話データ取得を目指します。

 なお、本稿の前提となるソフトウェア環境は、以下の通りです。

【開発環境】

  • Ubuntu 16.04.5 LTS
  • Python 3.6.4
  • Anaconda 5.1.0
  • TensorFlow-gpu 1.3.0
  • Keras 2.1.4
  • Jupyter 4.4.0

【実行環境】

  • Ubuntu 16.04.5 LTS
  • Python 3.6.7
  • TensorFlow 1.4.1
  • Keras 2.0.6

3. TwitterAPIの利用方法

3-1. アプリケーションの登録

 TwitterAPIを利用するには、Twitterの開発者用サイトで、自分が作ろうとするアプリケーションを登録して、そのアプリケーションが認証に使用する認証キーを発行してもらう必要があります。その方法ですが、2018年の7月から変更されていますので、注意が必要です。こちらの記事に新しい登録方法が載っていますので、ご確認ください。

 認証キーは、都合4種類あります(Consumer Key、Consumer Secret、Access Token、Access Token Secret)。

3-2. ライブラリのダウンロード

 こちらの記事を参考にしました。以下のコマンドを投入します。

$ pip install requests requests_oauthlib

4. 会話データ収集処理

4-1. 概要

 全体的な方式はこちらのサイトを参考にしました。基本的な流れは、①認証キーをもとにセッション確立②ツイート取得回数制限情報取得③応答ツイート取得(100メッセージ)④発話ツイート取得⑤不要文字削除⑥ファイル書き込み、となります。②から⑥までを繰り返して、発話文と応答文の対を収集します。

fig_tweet1.png

4-2. セッション確立

 3-1節で取得した認証キーを引数にAPIを発行し、リターンとしてセッションオブジェクトを得ます。今回の実装では、認証キーはハードコーディングします。

4-3. ツイート取得回数制限情報取得

 Twitterのツイートダウンロードには、時間当たりの回数制限があります。あと何回getメソッドの発行が可能か、発行が規定回数に達したら制限解除まで何秒待てばよいか、という情報もAPIで取得可能です。

 そこで、上述のサイトを参考に実装しました。

 今回の実装では、3種類のgetメソードを使用しますので、それぞれに対する制限情報を取得します。取得先URL、受信したjsonから情報を取得するためのキー、および15分あたりのgetメソード発行制限回数は以下の通りです。

種類 取得先URL キー1 キー2 キー3 15分あたり制限回数
応答ツイート取得 https://api.twitter.com/1.1/search/tweets.json 'resources' 'search' '/search/tweets' 180
発話ツイート取得 https://api.twitter.com/1.1/statuses/lookup.json 'resources' 'statuses' '/statuses/lookup' 900
制限情報取得 https://api.twitter.com/1.1/application/rate_limit_status.json 'resources' 'application' '/statuses/lookup' 180

4-4. 応答ツイート取得

 各ツイートはツイートされた時点では、まだ応答されていないので、仮にその後応答があったとしても、ツイート情報から応答ツイートを取得することはできません。しかし、そのツイートが別のツイートの応答かどうかは、調べる方法があります。

 その手順ですが、まずツイートを取得し、その中から他のツイートに対する応答であるものを抽出し、発話ツイートを特定して取得するという流れで、発話と応答の対を収集します。

 ツイートの取得には、セッションオブジェクトのgetメソッドを使用します。1回のメソッドで取得できるツイート数の上限は100なので、パラメータにその値を設定してメソッドを発行します。

 なお、ループを回す関係上、おなじツイートを2度受信、処理してしまう可能性があります。これを防ぐため、処理対象のツイートをツイート時刻を使って絞り込みます。時刻処理については、こちらの記事を参考にしました。

 ツイート時刻は、次のようにして取得します。getメソッドのリターンとして得られるツイート情報は、json形式になっていますが、いくつかあるキーのうち、「id」キーの値から、右側22bitを削除した残りが時刻情報になっています。各ツイートからこの値を集め、その最大値を次回処理に引き渡してやります。

 次のループ処理のなかで、引き渡された値より時刻情報が大きいツイートのみ処理することで、ツイートの重複処理を防ぎます。

4-5. 発話ツイート取得

 こちらの記事を参考にしました。

 4-3節で発行したgetメソッドで得たツイートのうち、「in_reply_to_status_id_str」キーの値がブランクでないものは、そのidを付与されたツイートに対する応答です。今度はそのidを引数にgetメソッドを呼ぶことで、発話ツイートを取得できます。

4-6. 不要文字削除

 こちらの記事を参考に、不要文字(URLなど)を削除します。

 また、絵文字も「駆逐」します。絵文字の削除には苦労したので、ググって見つけた2つの方法を、両方とも取り入れました。こちらこちらです。

4-7. ファイル書き込み

 収集した発話ツイートの先頭に識別子「REQ:」を、応答ツイートの先頭に識別子「RES:」を付与して、ファイルに書き込みます。付与した識別子は、後で訓練データを作成する際に使用します。

 ファイルは「tweet」フォルダに「tweetYYYY-MM-DD.txt」という名称で格納されます。ソースコードを実行する際には、実行フォルダの直下にあらかじめ、「tweet」フォルダを作成しておいてください。

4-8. ソースコード

 以下にソースコードを提示します。認証キーはつぶしてありますので、使用される際には、各自で入手した認証キーで埋めてください。

conversation_py3.py
# coding: utf-8

# tweet取得処理

# In[1]:


# -*- coding: utf-8 -*-

#*******************************************************************************
#                                                                              *
# tweet_idから時刻情報を取得する                                               *
#                                                                              *
#*******************************************************************************    
def tweet_id2time(tweet_id) :
    id_bin = bin(tweet_id>>22)
    tweet_time=int(id_bin,2)
    tweet_time += 1288834974657
    return tweet_time

#*******************************************************************************
#                                                                              *
# 発話tweet本文取得                                                            *
#                                                                              *
#*******************************************************************************    
def getTweet(res,start_time,reset):
    res_text = json.loads(res.text)
    url1 = 'https://api.twitter.com/1.1/statuses/user_timeline.json'    #今回こちらは使わない
    url2 = 'https://api.twitter.com/1.1/statuses/lookup.json'

    cnt_req = 0
    max_tweet=start_time

    total_text = []                           # tweet本文(発話/応答)のリスト
    tweet_list = []                           # n_reply_to_status_idと応答tweetの対のリスト
    #--------------------------------------------------------------------------*
    #                                                                          *
    # 応答tweet抽出取得                                                        *
    #                                                                          *
    #--------------------------------------------------------------------------*        
    for tweet in res_text['statuses']:
        status_id = tweet['in_reply_to_status_id_str']
        tweet_id=tweet['id']                  # 応答tweetのid

        if status_id != None :               # 当該tweetが応答かどうかの判断

            tweet_time = tweet_id2time(tweet_id)
            if tweet_time <= start_time :    # 前回処理より新しいtweetのみ処理する
                continue

            if max_tweet < tweet_time :
                max_tweet = tweet_time

            res_sentence = tweet['text']
            #RTを対象外にする
            if res_sentence[0:3] == "RT " :
                continue

            res_sentence = screening(res_sentence)
            if res_sentence == '' :
                continue

            tweet_list.append([status_id,res_sentence])


    if len(tweet_list) == 0 :
        return max_tweet,cnt_req ,total_text

    #複数status_idを連結する   
    id_list = tweet_list[0][0]
    for i in range(1,len(tweet_list)) :
        id_list += ','
        id_list += tweet_list[i][0]

    #--------------------------------------------------------------------------*
    #                                                                          *
    # 発話tweet抽出取得                                                        *
    #                                                                          *
    #--------------------------------------------------------------------------*   

    #複数status_id指定で発話tweet取得
    unavailableCnt = 0
    while True :
        try :
            req = session.get(url2, params = {'id':id_list ,'count':len(tweet_list)})
        except SocketError as e:
            print('ソケットエラー errno=',e.errno)
            if unavailableCnt > 10:
                raise

            waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
            unavailableCnt += 1
            continue

        if req.status_code == 503:
            # 503 : Service Unavailable
            if unavailableCnt > 10:
                raise Exception('Twitter API error %d' % res.status_code)

            unavailableCnt += 1
            print ('Service Unavailable 503')
            waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
            continue

        unavailableCnt = 0

        if req.status_code == 200 :
            req_text = json.loads(req.text)
            break
        else :
            raise Exception('Twitter API error %d' % res.status_code)    

    # 発話tweet本文スクリーニング
    for i in range(0,len(tweet_list)) :
        for j in range(0,len(req_text)) :
            if req_text[j]['id_str'] == tweet_list[i][0] :
                req_sentence = req_text[j]['text']

                if len(req_text) <= 0 :
                    print(req_text)
                    continue

                req_sentence = req_text[j]['text'] 
                #RTを対象外にする
                if req_sentence[0:3] == "RT " :
                    continue

                req_sentence = screening(req_sentence)

                #スクリーニングの結果、ブランクだったら対象外
                if req_sentence == '' :
                    continue   
                # 発話tweetと応答tweetを対で書き込み
                if req_sentence != tweet_list[i][1] :      
                    total_text.append("REQ:"+req_sentence)
                    total_text.append('RES:'+tweet_list[i][1])
                    cnt_req += 1

    max_tweet = max(max_tweet,start_time)
    return max_tweet,cnt_req ,total_text


# tweet本文スクリーニング

# In[2]:


#*******************************************************************************
#                                                                              *
# tweet本文スクリーニング                                                      *
#                                                                              *
#*******************************************************************************    
def screening(text) :
    s = text

    #RTを外す
    if s[0:3] == "RT " :
        s = s.replace(s[0:3],"")
    #@screen_nameを外す
    while s.find("@") != -1 :
        index_at = s.find("@")
        if s.find(" ") != -1  :
            index_sp = s.find(" ",index_at)
            if index_sp != -1 :
                s = s.replace(s[index_at:index_sp+1],"")
            else :
                s = s.replace(s[index_at:],"")
        else :
            s = s.replace(s[index_at:],"")

    #改行を外す
    while s.find("\n") != -1 :
        index_ret = s.find("\n")
        s = s.replace(s[index_ret],"")

    #URLを外す
    s = re.sub(r'https?://[\w/:%#\$&\?\(\)~\.=\+\-…]+', "", s)
    #絵文字を「。」に置き換え その1
    non_bmp_map = dict.fromkeys(range(0x10000, sys.maxunicode + 1), '。')
    s = s.translate(non_bmp_map)
    #絵文字を「。」に置き換え その2
    s=''.join(c if c not in emoji.UNICODE_EMOJI else '。' for c in s  )

    #置き換えた「。」が連続していたら1つにまとめる
    while s.find('。。') != -1 :
        index_period = s.find('。。')
        s = s.replace(s[index_period:index_period+2],'。')

    #ハッシュタグを外す
    while s.find('#') != -1 :
        index_hash = s.find('#') 
        s = s[0:index_hash]

    return s


# 回数制限を問合せ、アクセス可能になるまで wait する

# In[3]:



#*******************************************************************************
#                                                                              *
# 回数制限を問合せ、アクセス可能になるまで wait する                           *
#                                                                              *
#*******************************************************************************
def checkLimit(session):
    unavailableCnt = 0
    url = "https://api.twitter.com/1.1/application/rate_limit_status.json"

    while True :
        try :
            res = session.get(url)
        except SocketError as e:
            print('erron=',e.errno)
            print('ソケットエラー')
            if unavailableCnt > 10:
                raise

            waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
            unavailableCnt += 1
            continue

        if res.status_code == 503:
            # 503 : Service Unavailable
            if unavailableCnt > 10:
                raise Exception('Twitter API error %d' % res.status_code)

            unavailableCnt += 1
            print ('Service Unavailable 503')
            waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
            continue

        unavailableCnt = 0

        if res.status_code != 200:
            raise Exception('Twitter API error %d' % res.status_code)

        remaining_search,remaining_user, remaining_limit ,reset = getLimitContext(json.loads(res.text))
        if remaining_search <= 1 or remaining_user <=1 or remaining_limit <= 1:
            waitUntilReset(reset+30)
        else :
            break

    sec = reset - time.mktime(datetime.datetime.now().timetuple())
    print(remaining_search,remaining_user, remaining_limit ,sec)
    return reset

#*******************************************************************************
#                                                                              *
# sleep処理 resetで指定した時間スリープする                                   *
#                                                                              *
#*******************************************************************************
def waitUntilReset(reset):
    seconds = reset - time.mktime(datetime.datetime.now().timetuple())
    seconds = max(seconds, 0)
    print ('\n     =====================')
    print ('     == waiting %d sec ==' % seconds)
    print ('     =====================')
    sys.stdout.flush()
    time.sleep(seconds + 10)  # 念のため + 10 秒

#*******************************************************************************
#                                                                              *
# 回数制限情報取得                                                             *
#                                                                              *
#*******************************************************************************    
def getLimitContext(res_text):
    # searchの制限情報
    remaining_search = res_text['resources']['search']['/search/tweets']['remaining']
    reset1     = res_text['resources']['search']['/search/tweets']['reset']
    # lookupの制限情報
    remaining_user = res_text['resources']['statuses']['/statuses/lookup']['remaining']
    reset2     = res_text['resources']['statuses']['/statuses/lookup']['reset']
    # 制限情報取得の制限情報
    remaining_limit = res_text['resources']['application']['/application/rate_limit_status']['remaining']
    reset3     = res_text['resources']['application']['/application/rate_limit_status']['reset']

    return int(remaining_search),int(remaining_user),int(remaining_limit) ,max(int(reset1),int(reset2),int(reset3))


# メイン処理

# In[8]:


#*******************************************************************************
#                                                                              *
# メイン処理                                                                   *
#                                                                              *
#*******************************************************************************    
if __name__ == '__main__':

    from requests_oauthlib import OAuth1Session
    import json
    import datetime, time, sys
    import re
    import datetime
    import emoji
    import sys

    from socket import error as SocketError
    import errno


    CK = '*************************'                             # Consumer Key
    CS = '**************************************************'    # Consumer Secret
    AT = '**************************************************'    # Access Token
    AS = '*********************************************'         # Accesss Token Secert

    args = sys.argv
    #args[1] = '私'                                              # jupyter上で実行するとき用

    session = OAuth1Session(CK, CS, AT, AS)

    #--------------------------------------------------------------------------*
    #                                                                          *
    # tweet取得処理                                                            *
    #                                                                          *
    #--------------------------------------------------------------------------*
    total= -1
    total_count = 0
    cnt = 0
    unavailableCnt = 0
    url = 'https://api.twitter.com/1.1/search/tweets.json'

    start_time = 1288834974657
    while True:
        #----------------
        # 回数制限を確認
        #----------------
        #
        reset  = checkLimit(session) 
        get_time = time.mktime(datetime.datetime.now().timetuple()) #getの時刻取得
        try :
            res = session.get(url, params = {'q':args[1], 'count':100})
        except SocketError as e:
            print('ソケットエラー errno=',e.errno)
            if unavailableCnt > 10:
                raise

            waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
            unavailableCnt += 1
            continue

        if res.status_code == 503:
            # 503 : Service Unavailable
            if unavailableCnt > 10:
                raise Exception('Twitter API error %d' % res.status_code)

            unavailableCnt += 1
            print ('Service Unavailable 503')
            waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 30)
            continue

        unavailableCnt = 0

        if res.status_code != 200:
            raise Exception('Twitter API error %d' % res.status_code)

        #----------------
        # 取得したtweetに対する発話取得とファイル書き込み
        #----------------
        start_time ,count ,total_text = getTweet(res,start_time,reset)

        date = datetime.date.today()
        fname = 'tweet/tweet'+str(date)+'.txt'

        f=open(fname,'a')
        for i in range(0,len(total_text)):
            f.write(str(total_text[i])+"\n")
        f.close()

        total_count += count
        print('total_count=',total_count,'start_time=',start_time)

        current_time = time.mktime(datetime.datetime.now().timetuple()) 
        # 処理時間が2秒未満なら2秒wait
        if current_time - get_time < 2 :
            waitUntilReset(time.mktime(datetime.datetime.now().timetuple()) + 2)

        #デバッグ用
        if total > 0 :
            cnt += 100
        if total > 0 and cnt >= total:
            break

5. 会話データ収集処理実行

 ソースコードのファイルを置いたフォルダ上で、以下のコマンドを投入します。以下の例では、ファイル名は「conversation_py3.py」です。どこで実行しても問題ありませんが、実行前にあらかじめ、収集データ格納用の「tweet」フォルダを直下に作成しておいてください。

 引数として、応答ツイート収集のための検索ワードを指定します。以下の例では、「です」を指定してあります。

$ python conversation_py3.py です

 実際の実行時には、検索ワードとして、「です」のほかに、「私」「ある」を指定しました。

 1日じゅう起動しっぱなしにしておきましたが、面白いようにツイートを取得できます。1日あたり、20万対以上の会話対を集めることができ、5日程度で、目標である100万会話対を取得することができました。

 昼間のうちは問題なく動作していましたが、特に夜半を過ぎて午前2時以降になると、500番エラーが発生して、コマンドが停止していることが有りました。

 ソケットエラーは比較的よく発生するので、リトライするようにしてあります。

6. おわりに

 以上、チャットボット訓練用に、ツイートを大量に収集する方法について記述しました。今回の手法で収集したデータを用いて、以前作成したKerasベースのSeq2Seqモデルを訓練しましたが、その結果は、以下の投稿にまとめてあります。

  1. Twitterデータを用いたチャットボットの訓練
  2. Twitterデータを用いたチャットボットの訓練 -その2 処理性能とメモリ使用量改善
39
47
12

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
39
47