5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

もう一人の自分を作ってみる(モデル作成編)

Posted at

機械学習について少し勉強する機会があり、そのつながりで以前からやりたかった「もう一人の自分を作る」というテーマに取り組んでみました。

概要

python, kerasで自分っぽいツイートをするTwitter botを作りました。
機械学習に関しての知見があまりないので間違ってる部分も多いかもしれませんが気にしないでください。

1. データ収集

最初は以下の記事をもとにAPIを叩いて自分のアカウントの最新ツイートを取得しました。
PythonでTwitterスクレイピング&データフレーム化

......が、Twitter APIの都合上最新3000件しか取得できないということで、結局Twilogからログを取ってきました。
50000件のツイートをもとに学習を進めていきます。

2. データ処理

処理しなければいけない部分は以下の4種類です。

  • URL
  • ハッシュタグ
  • メンション
  • 公式RT

(公式RTについてはTwilogの場合設定によります)
ハッシュタグの場合は#~~~、メンションの場合は@~~~、公式RTの場合は文頭にRT @~~~:という感じなので処理してみます。

pp_texts = []

# 開始文字終了文字
start = '#'
fin = '@'

patterns = [
            r"https?://[\w!\?/\+\-_~=;\.,\*&@#\$%\(\)'\[\]]+", # URL
            r"#(\w+) *", # ハッシュタグ
            r"@([A-Za-z0-9_]+) *", # ID
            r"[#@]" # なんかよくわからん急に出てくる#と@の削除
            ]
RT_pattern = r"^RT @" # 公式RT

for l in texts + texts_twilog:
  tmp_l = l
  if re.match(RT_pattern, tmp_l): # 公式RTパターンに合致したら飛ばす
    continue
  else:
    for pattern in patterns: # 正規表現パターン合致部分を削除
      tmp_l = re.sub(pattern, '', tmp_l)
  tmp_l = re.sub(r"(\r?\n)+", " ", tmp_l) # 改行処理
  if len(tmp_l) > 0:
    pp_texts.append(start + tmp_l + fin) # 開始文字・終了文字

基本的にはreモジュールを用いて正規表現で処理していきます。
URL, ハッシュタグ, メンションは該当箇所を削除、公式RTの場合は該当文を丸ごと削除という形です。
何故か急に#やら@やらが出てくることがあったので、patternのlistに後から追加しています。また、後々の都合上、処理後の文頭に開始文字として#、終了文字として@を追記しています。


処理したデータを読み込みます。
これ以降の処理、モデル、学習などはこちらを参考に行っています。

joined_text = ''.join(pp_texts)

chars = sorted(list(set(joined_text))) # データセット中に出現する文字のリスト
char_indices = dict((c, i) for i, c in enumerate(chars)) # char->indexのdict
indices_char = dict((i, c) for i, c in enumerate(chars)) # index->charのdict

maxlen = 4
"""
入力長
日本語は文字数に対して情報量が多いのであまり長くするとうまく学習できないらしい
"""
step = 1
sentences = []
next_chars = []
for l in pp_texts:
    for i in range(0, len(l) - maxlen, step):
        sentences.append(l[i: i + maxlen]) # 入力
        next_chars.append(l[i + maxlen]) # 教師データ
# shuffle
p = list(zip(sentences, next_chars))
p = random.sample(p, len(p))
sentences, next_chars = zip(*p)

# 開始文のlist
first_words = [x for x in sentences if x.startswith(start)]

データセットに出現する文字それぞれに数字を対応させています。
(この時点ではまだindexに変換していません)

今回のモデルは、

  • 直前4文字を入力
  • 次の1文字を予測
  • 入力の後半3文字+予測1文字を再度入力
  • 終了文字が出力されたら終了

という形で文章を生成しています。
なので、window_size=4, stride=1でデータを取得することで5万件のツイートからおよそ80万件弱のデータを取得することが出来ました。すごい!

後々出力の際に必要な文頭数文字(今回は開始文字+文頭3文字の合計4文字)のリストを取得しています。

3. モデル構築

色々調べてみたのですが、文章生成に関する良い資料が見つからなかったので(何かあったら教えてください)、手探りで何パターンか試しました。
結果、試した中ではGRU(128units)-GRU(64units)-Denseが一番うまくいったのでそれを採用しました。

def build_GRU_model(maxlen, chars):
    model = Sequential()
    model.add(GRU(128, return_sequences=True, input_shape=(maxlen, len(chars))))
    model.add(GRU(64))
    model.add(Dense(len(chars), activation='softmax'))
    return model
# build the model
K.clear_session()
model = build_GRU_model(maxlen, chars)
"""
chars: list(str)
  データセット中に出現する文字のリストです(後述)。
  ストップワード除去などは行っていません。
"""

optimizer = RMSprop(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)
model.summary()
"""
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
gru (GRU)                    (None, 4, 128)            1653504   
_________________________________________________________________
gru_1 (GRU)                  (None, 64)                37056     
_________________________________________________________________
dense (Dense)                (None, 4177)              271505    
=================================================================
Total params: 1,962,065
Trainable params: 1,962,065
Non-trainable params: 0
_________________________________________________________________
"""

GRUはLSTMを簡略化した時系列解析モデルです。
詳しくはこちらを参考にしてください。

4. 学習

では学習を進めていきましょう。
データとモデルは上記のコードですでに準備できているので普通に学習し、、、

たかったのですが、先述の80万件弱のデータをnumpy配列として展開したところ普通にメモリが枯渇したので今回はdata generatorを使ってみます。
generatorを用いてバッチごとに必要なデータだけをメモリに展開することでメモリを節約することが出来ます。

# data generator
def batch_iter(data, labels, batch_size, shuffle=True):
    num_batches_per_epoch = int((len(data) - 1) / batch_size) + 1

    def data_generator():
        data_size = len(data)
        while True:
            """
            dataとlabelsが両方numpy配列じゃないと
            numpy配列によるインデックス指定はできない
            今回はメモリの関係で両方tuple(or list)なので
            違うやり方でshuffleする
            """
            if shuffle:
                p = list(zip(sentences, next_chars))
                p = random.sample(p, len(p))
                shuffled_data, shuffled_labels = zip(*p)
                # shuffled_data, shuffled_labels: tuple(str)
            else:
                shuffled_data = data
                shuffled_labels = labels

            for batch_num in range(num_batches_per_epoch):
                # batch_size分のindexを取得
                start_index = batch_num * batch_size
                end_index = min((batch_num + 1) * batch_size, data_size)
                X = shuffled_data[start_index: end_index]
                y = shuffled_labels[start_index: end_index]
                # vectorization
                # シャッフル済みのX, yからbatch_size分を取得してone-hot encode
                X_vec = np.zeros((len(X), maxlen, len(chars)), dtype=np.bool)
                y_vec = np.zeros((len(y), len(chars)), dtype=np.bool)
                for i, sentence in enumerate(X):
                    for t, char in enumerate(sentence):
                        X_vec[i, t, char_indices[char]] = 1
                    y_vec[i, char_indices[y[i]]] = 1
                yield X_vec, y_vec

    return num_batches_per_epoch, data_generator()

どこかのコードを参考にしてたんですが見つからなかったので頑張って探してください。

上記generatorと以下関数群を用いて学習していきます。

使う関数群
"""
predict
"""
def more_tweet(model, num=10, max_char=140):
    tweet_list = []
    for i in range(num):
        text_tmp = ''
        start_index = random.randint(0, len(first_words)-1) # 開始単語
        
        diversity = 0.2
        generated = ''
        sentence = first_words[start_index] # 開始sentenceリストから引っ張る
        generated += sentence
        
        # ここは開始文字**除いて**print(しない)
        text_tmp = generated[1:]

        for i in range(max_char):
            x_pred = np.zeros((1, maxlen, len(chars)))
            
            # sentenceの各文字をindex化
            for t, char in enumerate(sentence):
                x_pred[0, t, char_indices[char]] = 1.

            # 次の文字を予測
            preds = model.predict(x_pred, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            # next_charが終了文字('$')だった場合breakする
            if next_char == fin:
                break

            # sentenceを右シフト
            sentence = sentence[1:] + next_char
            
            text_tmp = text_tmp + next_char
            
        print('漏れ『' + text_tmp + '')
        tweet_list.append(text_tmp)
    return tweet_list    


"""
indexをもらうやつ
temperatureが高いほど低確度のindexも出力しやすくなる
(普通にdiversityでいいのでは?)
""" 
def sample(preds, temperature=1.0):
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

"""
各epoch終了時に特定の開始文字から予測を行って学習結果を確かめる
"""
def on_epoch_end(epoch, logs):    
    start_index = 1   # 開始単語の固定
    for diversity in [0.2]: # diversity = 0.2 のみとする
        generated = ''
        sentence = first_words[start_index] # 開始sentenceリストから引っ張る
        generated += sentence
        
        # ここは開始文字込みでprint
        print('\n----- Generating with seed: "' + sentence + '"')
        
        # ここは開始文字**除いて**print
        sys.stdout.write(generated[1:])

        for i in range(100):
            x_pred = np.zeros((1, maxlen, len(chars)))
            
            # sentenceの各文字をindex化
            for t, char in enumerate(sentence):
                x_pred[0, t, char_indices[char]] = 1.

            # 次の文字を予測
            preds = model.predict(x_pred, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            # next_charが終了文字('@')だった場合breakする
            if next_char == fin:
                break

            # sentenceを右シフト
            sentence = sentence[1:] + next_char
            
            # next_charをバッファにぶち込む
            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()

"""
batchあたりのlossを記録するcallback
"""
class LossHistory(tf.keras.callbacks.Callback):
        def on_train_begin(self, logs={}):
            self.losses = []

        def on_batch_end(self, batch, logs={}):
            self.losses.append(logs.get('loss'))
# datetime object
JST = timezone(timedelta(hours=+9), 'JST')
datetimeobj = datetime.now(JST)
datetimestr = datetimeobj.strftime('%Y/%m/%d %H:%M')
print(f'* * * * * Start on {datetimestr} JST * * * * *')

"""
prj_root: projectのrootdirのpath
  ローカル・クラウド環境両方で学習等行ったため変数に出してある
"""
# make today directory
os.makedirs(prj_root + 'data/model/' + datetimeobj.strftime('%m%d'), exist_ok=True)
os.makedirs(prj_root + 'data/loss/' + datetimeobj.strftime('%m%d'), exist_ok=True)

print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
save_model = ModelCheckpoint(prj_root + 'data/model/' + datetimeobj.strftime('%m%d') + '/gru_gen_check_' + datetimeobj.strftime('%H%M') + '.h5')
batch_history = LossHistory()

# get generator
train_steps, train_batches = batch_iter(sentences, next_chars, 4096)

history = model.fit_generator(generator=train_batches,
                              steps_per_epoch=train_steps,
                              epochs=50,
                              callbacks=[
                                        print_callback,
                                        save_model,
                                        batch_history
                                        ])

# save model
model.save(prj_root + 'data/model/' + datetimeobj.strftime('%m%d') + '/GRU_full_' + datetimeobj.strftime('%H%M') + '.h5')

# plot batch loss(省略)

# Plot Training loss & Validation Loss(省略)

各epochの終わりと学習の終わりにモデルを保存しています。


データとモデルなどの環境が整いました。
ということで、学習を進めます。
以下、出力結果です。

* * * * * Start on 2020/07/31 14:04 JST * * * * *
Epoch 1/50
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
187/188 [============================>.] - ETA: 0s - loss: 4.3052
----- Generating with seed: "#ラーメ"
ラーメントはある
188/188 [==============================] - 60s 321ms/step - loss: 4.3012
Epoch 2/50
187/188 [============================>.] - ETA: 0s - loss: 3.3113
----- Generating with seed: "#ラーメ"
ラーメン
188/188 [==============================] - 53s 285ms/step - loss: 3.3108
Epoch 3/50
187/188 [============================>.] - ETA: 0s - loss: 3.0455
----- Generating with seed: "#ラーメ"
ラーメン荘 歴史はそういう話
188/188 [==============================] - 54s 287ms/step - loss: 3.0455
Epoch 4/50
187/188 [============================>.] - ETA: 0s - loss: 2.8945
----- Generating with seed: "#ラーメ"
ラーメン
188/188 [==============================] - 53s 284ms/step - loss: 2.8945
Epoch 5/50
187/188 [============================>.] - ETA: 0s - loss: 2.7894
----- Generating with seed: "#ラーメ"
ラーメン屋さんになってる
188/188 [==============================] - 53s 280ms/step - loss: 2.7895

### 省略 ###

Epoch 46/50
187/188 [============================>.] - ETA: 0s - loss: 2.0753
----- Generating with seed: "#ラーメ"
ラーメン荘 歴史を刻め六甲道になった
188/188 [==============================] - 52s 279ms/step - loss: 2.0762
Epoch 47/50
187/188 [============================>.] - ETA: 0s - loss: 2.0711
----- Generating with seed: "#ラーメ"
ラーメン食べたい
188/188 [==============================] - 53s 280ms/step - loss: 2.0712
Epoch 48/50
187/188 [============================>.] - ETA: 0s - loss: 2.0676
----- Generating with seed: "#ラーメ"
ラーメンが届いた
188/188 [==============================] - 53s 280ms/step - loss: 2.0679
Epoch 49/50
187/188 [============================>.] - ETA: 0s - loss: 2.0654
----- Generating with seed: "#ラーメ"
ラーメン食ってない
188/188 [==============================] - 52s 279ms/step - loss: 2.0655
Epoch 50/50
187/188 [============================>.] - ETA: 0s - loss: 2.0605
----- Generating with seed: "#ラーメ"
ラーメン荘 歴史を刻めたりするか
188/188 [==============================] - 53s 284ms/step - loss: 2.0607

いい感じに学習が進みました(たぶん)。
現段階では特定のseedでしか検証できていませんが、個人的にはいい感じに日本語の体裁を保ったうえで変な文章を生成できている印象があります。
もちろん適切なモデル作成/データ処理が行えておらず学習しきれていない可能性もありますが、テーマがテーマなのでその不完全さもまた面白さでしょう。言い訳をするな!

5. 検証

学習がとりあえず済んだので、今度はランダムに抽出した複数のseedを元にして検証を行っていきます。(とは言っても厳密な評価基準があるわけではないのですが)

楽しみですね、
俺を満足させてみろ!!!!!!!

# test
tweet_list = more_tweet(model, num=100, max_char=140)
# 諸事情により出力をファイルに書き出している
filepath = prj_root + 'data/new_tweets.txt'
with open(filepath, mode='a', encoding='utf-8') as f:
    f.write('\n'.join(tweet_list))

以下が出力結果です。
100件(だったかな?)出力してしまったので畳んであります。
また、一部よろしくない表現のものは削りました。まだ問題がありそうな文章があれば教えてください。

出力結果
漏れ『充電コードとかあるんだよな』
漏れ『助けてくれ』
漏れ『グラブルのほうがいいらしいね』
漏れ『流星とかいうのかな』
漏れ『何の仕事とは言わんが、、、』
漏れ『ハッテントレスみたいなのはカス』
漏れ『金玉デカすぎて中身のないラーメン荘 歴史を刻めた』
漏れ『お金でもしかしてない』
漏れ『AV?』
漏れ『勝利のがヤバい』
漏れ『飲み屋行くか』
漏れ『これ安いといたい』
漏れ『買いました。 』
漏れ『何故煽る⁉️』
漏れ『こんなんだよなぁ!植木鉢ィ!なぁ!これ!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!』
漏れ『年上がれるんだけど』
漏れ『運動するから』
漏れ『歌コンテンツもんな』
漏れ『今回の方がいいかな』
漏れ『思い出して泣いてる』
漏れ『今日の晩飯はインターネットの速度はどのくらいですか?』
漏れ『御社のたとは人間になりますね』
漏れ『今日からやってたらそうなんだよなぁ!』
漏れ『らしい』
漏れ『がんばります』
漏れ『薬買いに行くか』
漏れ『人じゃないからそれにした』
漏れ『サワコシサワコ イライッッペイタ エエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエエ』
漏れ『完全に「悪い』
漏れ『ゾス!』
漏れ『やっぱりオタクのくせにコンテンツの方がいいなぁという気持ちになってる』
漏れ『ティッシュカスなぁ』
漏れ『ノビヨガッだらなぁ』
漏れ『何が届いた』
漏れ『コンテンツしか勝たん』
漏れ『字がきたな』
漏れ『口座にチャリです』
漏れ『ほんまにそれそれください。』
漏れ『あおいちゃんのおたくはそれですね』
漏れ『彼女欲しいな』
漏れ『フィクションの務日はマジで好きなんだよなぁ!』
漏れ『今日の晩飯はインターネットの速度はどのくらいですか?』
漏れ『ノゾミの間に合わないといけないと思います。』
漏れ『クッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッッ』
漏れ『はーい』
漏れ『ゴミカスなんですか?』
漏れ『もテイカツーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーーー』
漏れ『すたー』
漏れ『違うならそれはそれですね』
漏れ『陰キャのほうが早いな』
漏れ『対戦よろしくお願いします。』
漏れ『香澄真昼、ちょっと違うけど』
漏れ『エグいねん』
漏れ『あんスペックでは?』
漏れ『バイトの時点でもうギリギリやしてる』
漏れ『わかなくなってきたな』
漏れ『キタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキタキ』
漏れ『無力感あるからそれにしよか』
漏れ『行きたいなぁ』
漏れ『金玉のほうがいいよ』
漏れ『頼むからコバコれてたらそうなんだよなぁ!』
漏れ『おめでとうございます。』
漏れ『AEBF :参戦ID 参加者募集! Lv120 シヴァ 』
漏れ『しばらくで一番好きなんだよな』
漏れ『何も来るなぁ』
漏れ『それ、もちかんかい!』
漏れ『工学部は卒業できるようになった』
漏れ『会見始めたんだよなぁ!』
漏れ『???????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????』
漏れ『夢のレインボーライブを見ています』
漏れ『明日の予定です』
漏れ『面白そうなもんよ』
漏れ『テノチヨ』
漏れ『逃してない』
漏れ『うちにも違いしてる』
漏れ『いいねー!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!』
漏れ『おたくクン、それはマジでやめてくれ』
漏れ『帰省するときはマジで好きなんだよなぁ!』
漏れ『俺もライブ忘れてた』
漏れ『ねむすぎ』
漏れ『これ 』
漏れ『酒イキリオタクがいる』
漏れ『個人情報ありません。』
漏れ『イエスクレイパー』
漏れ『途中ですか?』
漏れ『またオタクしかないんだよなぁ!オイ!なぁ!これ!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!』
漏れ『マジで美味いなぁ』
漏れ『からくれてもいいんだけど』
漏れ『Twitterですね』
漏れ『岩瀬超えでも一緒にプリティーシリーズ06 / わか・ふうり from AIKATSU☆STAR☆ANIS / アイカツはやめたい』
漏れ『俺も卵になってる』
漏れ『マジか?』
漏れ『キッチンペーパーマシューするか』
漏れ『救急車でイキるオタク、それも好き』

いい感じではないでしょうか!!!!!(何が?)

一部意味不明になっていたり文法がおかしかったりするものもありますが、全体的に味があっていいなと思います。

†𝑶𝒕𝒂𝒌𝒖 𝑾𝒐𝒓𝒅𝒔†が多かったり、某国民的大正義スマホファンタジーゲームの救援を出していたり、なかなかデータセットのバイアスが上手く表現されてる気がしますね。

5. まとめ

いろいろ放置してたら実施から執筆まで1か月くらいかかってしまった、、、

機械学習初学者も初学者のペーペーなのですが、サクッといい感じの成果が出せるとはフレームワークさまさまですね。やっぱりONEIROS生まれのKさんはスゴイ、俺はいろんな意味で思った
今回は経験が浅く学習中によく使っていたということもありKerasを使いましたが、これからTensorflow、PyTorchなど別のフレームワークも手を出していきたいです。

(今度はもっと生産的なテーマに取り組まなきゃ......)

6. おまけ:Twitter bot化

実はもう運用しているのですが、今回のモデルをGCPのGoogle Compute Engine上に載せてbot化しています。
この記事内で触れる予定でしたが、量が結構多くなってしまったので別記事にします(これから書きます)。

簡単な内容にはなりますが、「Twitterのbotに興味がある」「とりあえずクラウド触ってみたい」という方には多少役に立つようにしたいな~と思っております。少々お待ちください。

それでは。

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?