LoginSignup
8
9

More than 3 years have passed since last update.

Keras LSTMでアイドルっぽいツイートを作ってみる(文章生成)

Last updated at Posted at 2020-02-23

記事の内容

AIを使った文章の自動生成に興味がありKerasを使って文章生成をしてみました。
やった内容はアイドルグループのツイートを取得してきて学習させ、文章を作るというものです。

参考

こちらの記事を参考にさせていただきました。

Keras LSTMでサクッと文章生成をしてみる

コードは基本的には同じですが、この記事だけでは少し分かり辛かった部分などをメモとして残します。

ツイートを作ってみる

学習データの取得

まずは学習データの取得です。

get_time_lines.py
import json
import config
from requests_oauthlib import OAuth1Session
from time import sleep
import re
import emoji
from mongo_dao import MongoDAO

# 今回は使わない
# 絵文字を除去する
def remove_emoji(src_str):
    return ''.join(c for c in src_str if c not in emoji.UNICODE_EMOJI)

# APIキー設定(別ファイルのconfig.pyで定義しています)
CK = config.CONSUMER_KEY
CS = config.CONSUMER_SECRET
AT = config.ACCESS_TOKEN
ATS = config.ACCESS_TOKEN_SECRET

# 認証処理
twitter = OAuth1Session(CK, CS, AT, ATS)  

# タイムライン取得エンドポイント
url = "https://api.twitter.com/1.1/statuses/user_timeline.json"  

# 取得アカウント
necopla_menber = ['@yukino__NECOPLA', '@yurinaNECOPLA', '@riku_NECOPLA', '@miiNECOPLA', '@kaori_NECOPLA', '@sakuraNECOPLA', '@miriNECOPLA', '@renaNECOPLA']

# パラメータの定義
params = {'q': '-filter:retweets',
          'max_id': 0, # 取得を開始するID
          'count': 200}

# arg1:DB Name
# arg2:Collection Name
mongo = MongoDAO("db", "necopla_tweets")
mongo.delete_many({})

regex_twitter_account = '@[0-9a-zA-Z_]+'

for menber in necopla_menber:
    print(menber)
    del params['max_id'] # 取得を開始するIDをクリア
    # 最新の200件を取得/2回目以降はparams['max_id']に設定したIDより古いツイートを取得
    for j in range(100):
        params['screen_name'] = menber
        res = twitter.get(url, params=params)
        if res.status_code == 200:
            # API残り回数
            limit = res.headers['x-rate-limit-remaining']
            print("API remain: " + limit)
            if limit == 1:
                sleep(60*15)

            n = 0
            tweets = json.loads(res.text)
            # 処理中のアカウントからツイートが取得出来なくなったらループを抜ける
            if len(tweets) == 0:
                break
            # ツイート単位で処理する
            for tweet in tweets:
                # ツイートデータを丸ごと登録
                if not "RT @" in tweet['text'][0:4]:
                    mongo.insert_one({'tweet':re.sub(regex_twitter_account, '',tweet['text'].split('http')[0]).strip()})

                if len(tweets) >= 1:
                    params['max_id'] = tweets[-1]['id']-1

この記事の学習データには「//ネコプラ//」というグループのツイートを取得してきました。
取得したツイートは以下の要素を削除しています。
・画像のリンク
・リプライ時のアカウント

こんな感じで学習させるデータをmongoDBに突っ込んでいます。

{ "_id" : ObjectId("5e511a2ffac622266fb5801d"), "tweet" : "ソロ曲練習でもしようと思ってカラオケ行ったはいいものの普通に喉壊した" }
{ "_id" : ObjectId("5e511a2ffac622266fb5801e"), "tweet" : "まだまだいけるよ" }

mongoDB操作のソースはこちら

ツイートを学習させる

参考にさせていただいたソースとほぼ同じです。

keras_tweet_learning.py
from __future__ import print_function
from keras.callbacks import LambdaCallback
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.layers import LSTM
from keras.optimizers import RMSprop
from keras.utils.data_utils import get_file
import matplotlib.pyplot as plt   # 追加
import numpy as np
import random
import sys
import io
from mongo_dao import MongoDAO

mongo = MongoDAO("db", "necopla_tweets")
results = mongo.find()

text = ''
for result in results:
    text += result['tweet']

chars = sorted(list(set(text)))
print('total chars:', len(chars))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))

# cut the text in semi-redundant sequences of maxlen characters
maxlen = 3
step = 1
sentences = []
next_chars = []
for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i: i + maxlen])
    next_chars.append(text[i + maxlen])
print('nb sequences:', len(sentences))

print('Vectorization...')
x = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)
y = np.zeros((len(sentences), len(chars)), dtype=np.bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1


# build the model: a single LSTM
print('Build model...')
model = Sequential()
model.add(LSTM(128, input_shape=(maxlen, len(chars))))
model.add(Dense(len(chars)))
model.add(Activation('softmax'))

optimizer = RMSprop(lr=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)


def sample(preds, temperature=1.0):
    # helper function to sample an index from a probability array
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

# epoch 終了時に実行
def on_epoch_end(epoch, logs):
    # Function invoked at end of each epoch. Prints generated text.
    print()
    print('----- Generating text after Epoch: %d' % epoch)

    # start_index = random.randint(0, len(text) - maxlen - 1)
    # start_index = 0  # 毎回、「老人は老いていた」から文章生成
    for diversity in [0.2]:  # diversity = 0.2 のみとする
        print('----- diversity:', diversity)

        generated = ''
        # sentence = text[start_index: start_index + maxlen]
        sentence = '明日は'
        generated += sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)

        for i in range(120):
            x_pred = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x_pred[0, t, char_indices[char]] = 1.

            preds = model.predict(x_pred, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            generated += next_char
            sentence = sentence[1:] + next_char

            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()

# 学習終了時に実行
def on_train_end(logs):
    print('----- saving model...')
    model.save_weights("necopla_model" + 'w.hdf5')
    model.save("necopla_model.hdf5")

print_callback = LambdaCallback(on_epoch_end=on_epoch_end,
                                on_train_end=on_train_end)

history = model.fit(x, y,
                    batch_size=128,
                    epochs=5,
                    callbacks=[print_callback])

# Plot Training loss & Validation Loss
loss = history.history["loss"]
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, "bo", label = "Training loss" )
plt.title("Training loss")
plt.legend()
plt.savefig("loss.png")
plt.close()

変更した点としては以下です。
・学習終了時に学習データを保存する処理を追加(on_train_end)
・文字の長さ(maxlen)を8→3に変更

1点目は丸々1日かけて動かした学習データが保存されてないという痛い目を見てしまったので保存処理を追加しました。
2点目は8文字で学習を行い、その学習データで"明日は"というような3文字の言葉から始まる文章を予測すると上手く動作しませんでした。
ツイートを作成するということもあり、短めの言葉から始まる学習をさせました。

学習の過程はこのような感じになりました。

Epoch 1/5
663305/663305 [==============================] - 401s 605us/step - loss: 3.5554

----- Generating text after Epoch: 0
----- diversity: 0.2
----- Generating with seed: "明日は"
明日は!!!!!!!!!!!!!!!!!!!!!!!!!
これからもよろしくお願いします!!!

「ネコプラ//ネコプラ
#ネコプラ
#ネコプラ//ネコプラ//ネコプラ//ネコプラ//ネコプラ
#ネコプラ
#ネコプラ//ネコプラの方がいいのですが
Epoch 2/5
663305/663305 [==============================] - 459s 693us/step - loss: 3.2893

----- Generating text after Epoch: 1
----- diversity: 0.2
----- Generating with seed: "明日は"
明日はこちらです!!

#ネコプラのライブです!

#ネコプラのライブがあります!!

#ネコプラのことはありがとうございます!!!

#ネコプラのことがんばってくれてありがとうございます!!

#ネコプラのことがあったら嬉しいです、、

#ネコ
Epoch 3/5
663305/663305 [==============================] - 492s 742us/step - loss: 3.2109

----- Generating text after Epoch: 2
----- diversity: 0.2
----- Generating with seed: "明日は"
明日はこちらです!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Epoch 4/5
663305/663305 [==============================] - 501s 755us/step - loss: 3.1692

----- Generating text after Epoch: 3
----- diversity: 0.2
----- Generating with seed: "明日は"
明日はこちらです!!






#ネコプラ 
#ネコプラ
#ネコプラの人に来てくれて嬉しいよね、、!!



#ネコプラ
#ネコプラ
#ネコプラのこともっともっともっともっともっともっともっともっともっともっともっともっともっとも
Epoch 5/5
663305/663305 [==============================] - 490s 739us/step - loss: 3.1407

----- Generating text after Epoch: 4
----- diversity: 0.2
----- Generating with seed: "明日は"
明日は会いに来てね!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

追加学習

保存した学習データを使って更に学習させる
モデル構築後に保存した学習データをloadしてあげるだけ。

コードはこちら

keras_addtinal_learning.py
print('Build model...')
model = Sequential()
model.add(LSTM(128, input_shape=(maxlen, len(chars))))
model.add(Dense(len(chars)))
model.add(Activation('softmax'))

optimizer = RMSprop(lr=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer)
model.load_weights("necopla_modelw.hdf5")

学習したデータを使ってツイートを作る

この処理を作るときにちゃんと一個一個のコードの意味を考えました。(遅い・・。)
学習のepoch終了時に「on_epoch_end」の処理が呼ばれますが、ここではepoch終了のタイミングでその時点の学習データを使って文章を作成しています。
そのため、ツイートを作るときは、基本的にこの処理を真似れば作れます。

コードはこちら

keras_create_tweet.py
def evaluate_tweet():
    for diversity in [0.2]:  # diversity = 0.2 のみとする
        print('----- diversity:', diversity)

        generated = ''
        sentence = '明日は'
        generated += sentence
        print('----- Generating with seed: "' + sentence + '"')

        sameCharCount = 0
        for i in range(120):
            x_pred = np.zeros((1, maxlen, len(chars)))

            for t, char in enumerate(sentence):
                x_pred[0, t, char_indices[char]] = 1.

            preds = model.predict(x_pred, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            if next_char == generated[-1]:
                sameCharCount += 1
                if sameCharCount >= 3: 
                    continue
            elif sameCharCount != 0:
                sameCharCount = 0
            generated += next_char
            sentence = sentence[1:] + next_char

    return generated

for i in range(10):
    tweet = evaluate_tweet()
    print('---------------- ' + str(i+1) + '回目 ---------------- ')
    print(tweet)

ツイート文章を作る時に3文字以上同じ文字が続くと文字を繋げないようにしました。
ツイートは「明日は」という言葉から始まる文章で作成しています。

出力結果はこんな感じになりました。

---------------- 1回目 ---------------- 
明日はこちらです!!!
---------------- 2回目 ---------------- 
明日はこちらです!!!
---------------- 3回目 ---------------- 
明日は会えるかな〜!!!
---------------- 4回目 ---------------- 
明日はこちらです!!!
---------------- 5回目 ---------------- 
明日はこちらです!!!
---------------- 6回目 ---------------- 
明日はこちらです!!!
---------------- 7回目 ----------------
明日はこちらです!!!
---------------- 8回目 ----------------
明日はこちらです!!!
---------------- 9回目 ----------------
明日はこちらです!!!
---------------- 10回目 ----------------
明日はこちらです!!!

なんとなくですが、それっぽい内容が出来ました。

感想

以前にChainerを使って同じことをやったことがありますが、断然Kerasでやる方が簡単で分かり易いです。
出力結果も学習回数が少ないのか短文になってしまうので、追加学習を行いながら結果を見てみようと思います。

学習方法も分かち書きで学習させる方法だと違う結果になると思うので、そちらもやってみようと思います。

8
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
9