2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

日本語版BERTを使って、Tweetのいいね数を予測するアプリを作ってみた【Django】

Last updated at Posted at 2021-12-27

こんにちは。kuro_Bです。

突然ですが、みなさんはTwitterをやっていますか?
自分は今年に入り、ようやくアカウントを作ったのですが、一つ切実な悩みがございます。

ネタツイートのいいね数を伸ばしたい!

そこで今回、
日本語版BERTという自然言語処理用の機械学習モデルを使用して、
Tweetのいいね数を事前に予測するアプリを作ってみました。

環境

  • Python: 3.9.9
  • Django: 4.0
  • transformers: 4.5.0
  • pytorch-lightning: 1.4.9

概要

今回以下の手順でアプリ作成を行いました。

  1. Twitter API v2を使用して、Tweet情報を取得
  2. 取得したTweet情報を元に、日本語版BERTを学習
  3. 学習済BERTとDjangoを使用し、アプリ作成

また、最後にAWSのEC2にてデプロイを行いました。
それでは、各手順を紹介していきます。

1.Twitter API v2を使用して、Tweet情報を取得

まずは、BERTを学習させるために必要なデータを収集します。
必要なデータは以下を考えました。

① Tweetについたいいね数(今回の予測対象のため必須)
② Tweet本文
③ Tweetに関連する情報(Tweetの投稿時間など)
④ Tweetをしたユーザーの情報(フォロワー数など)

これらのデータを、Twitter API v2を使用して収集します。

調べたところ、APIからTweetを取得するためには、検索単語を指定する必要がありそうでした。
ランダムにTweetを抽出できればベストだったのですが、
難しそうだったので、今回は検索単語を指定して、Tweetを抽出することにします。

色々考えた末、【忘年会】という単語を検索単語に指定します。←

Tweet収集コードはこちらなどを参考に、作成しました。

Tweet収集コード
extract_tweet.py
import os
import requests
import time
import json
import datetime
from box import Box

#現在時刻
dt_now = datetime.datetime.now()


config={
    'bearer_token':'<bearer_tokenを記載>'

    # 検索ワード  e.g. query = "テスト" / query = "テスト OR test"
    # OR 検索 AND検索 -検索 などしたい場合はそのように書く
    'query': "忘年会",

    'tweet_fields': ['created_at','public_metrics'] ,

    'options': {
        'max_results': '100',
        #defaultは7日前
        'start_time': (dt_now-datetime.timedelta(days=7)).strftime('%Y-%m-%dT%H:%M:%SZ'),
        #defaultは2日前。ある程度時間が経過しないと、いいね数が確定しないため
        'end_time': (dt_now-datetime.timedelta(days=2)).strftime('%Y-%m-%dT%H:%M:%SZ')
    },

    #expansionsに対応するfieldを辞書で格納
    'expansions': {'author_id':{
        'user.fields':['created_at','public_metrics']
        }}

}

config=Box(config)


#取得したい項目をパラメータに設定して、URLを作成している。
def create_url(query, tweet_fields, options, expansions):
    if len(tweet_fields)>=1:
        formatted_tweet_fields = '&tweet.fields=' + ','.join(tweet_fields)
    else:
         formatted_tweet_fields = ""

    if len(options)>=1:
        formatted_option_fields = []
        for k, v in options.items():
            formatted_option_field=f'{k}=' + f'{v}'
            formatted_option_fields.append(formatted_option_field)
        formatted_option_fields = '&'+'&'.join(formatted_option_fields)
    else:
         formatted_option_fields = ''

    expansions_key=list(expansions.keys())
    if len(expansions_key)>=1:
        formatted_expansions = '&expansions=' + ','.join(expansions_key)

        #対応するfieldを整理
        formatted_attr_fields=[]
        for key in expansions_key:
            attr_fields=expansions[key]
            for k, v in attr_fields.items():
                formatted_attr_field=f'{k}=' + ','.join(v)
                formatted_attr_fields.append(formatted_attr_field)

        formatted_attr_fields='&'+'&'.join(formatted_attr_fields)

    else:
         formatted_expansions = ''
         formatted_attr_fields = ''
         
    url = f'https://api.twitter.com/2/tweets/search/recent?query={query}{formatted_tweet_fields}{formatted_option_fields}{formatted_expansions}{formatted_attr_fields}'
    return url

#リクエスト用のheader 
def create_headers(bearer_token):
    headers = {'Authorization': f'Bearer {bearer_token}'}
    return headers

def main():

    # setting
    output_dir=os.path.dirname(__file__) + '/output'
    os.makedirs(output_dir, exist_ok=True)

    BEARER_TOKEN = config.bearer_token
    query = config.query
    tweet_fields = config.tweet_fields
    options = config.options
    expansions = config.expansions

    origin_url = create_url(query, tweet_fields, options, expansions)
    headers = create_headers(BEARER_TOKEN)
    
    #リクエスト実行。最初だけ、origin_url
    url=origin_url
    while True:
        
        response = requests.request("GET", url, headers=headers)

        if response.status_code == 200:
            #参考:https://toricor.hatenablog.com/entry/2016/01/16/160406
            with open(f'{output_dir}/{datetime.datetime.now()}.json', 'w') as f:
                json.dump(response.json(), f, indent=4, sort_keys=True, ensure_ascii=False)
            
            try:
                #もしも、next_tokenがあればそれに基づき、新しいurlを作成。
                next_token=response.json()['meta']['next_token']
                url=origin_url+'&next_token='+next_token

            #もしもnext_tokenがない(=次がない場合、ループを抜ける)
            except:
                break
        
        else:
            #api制限にかかった場合、スリープする。
            time.sleep(15*60)
        

if __name__ == "__main__":
    main()

2. 取得したTweet情報を元に、日本語版BERTを学習

次に、取得した情報を元に、日本語版BERTを学習していきます。
BERTは、Googleが提案した自然言語処理用の機械学習モデルで、
テキストデータを処理することができます。

今回は東北大学の乾研究室で作成された、日本語用のBERTを使用していきます。

モデルの学習にあたり、予測対象と予測に使用するデータ(=特徴量)を指定する必要があります。
今回は以下のように定義しました。

  • 予測対象 : ツイートのいいね数
  • 特徴量 :
    • Tweet本文
    • Tweet投稿時間
    • フォロワー数

(リツイート数なども有効な特徴量かと思いましたが、アプリに入力する時点ではわからないため、今回は除外しています。)

モデルの学習にはpytorch lightningという、深層学習系ライブラリを使用しました。

pytorch系列のライブラリは、モデルを学習する際、データセットとモデルをクラスとして事前に定義します。
今回は以下のように定義しました。

pytorch-lightning用のデータセット
dataset
class TweetLikePredDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.data = df
        self.tokenizer = tokenizer

        self.encode_text=tokenizer(
            text=self.data.text.tolist(),
            return_attention_mask=True,
            truncation=True,
            max_length=192,
            padding='max_length'
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        items={
            'input_ids' : torch.tensor(self.encode_text['input_ids'][idx]),
            'attention_mask' : torch.tensor(self.encode_text['attention_mask'][idx]),
            'tweet_hour' : torch.tensor(self.data['tweet_hour'][idx], dtype=torch.float32),
            'followers_count' : torch.tensor(self.data['followers_count'][idx], dtype=torch.float32),
            'like_count' : torch.tensor(self.data['like_count'][idx], dtype=torch.float32)
        }
    
        return items
pytorch-lightning用のモデル
model
class TweetLikePredModel(pl.LightningModule):
    def __init__(
        self,
        tokenizer,
        cfg,
        t_dataloader,
        v_dataloader
    ):
        #superで親クラスのメソッドを使用。
        super().__init__()
        #gradient_accumulateのため、マニュアル
        self.automatic_optimization = False

        #config
        self.weight_decay=cfg.model.weight_decay
        self.learning_rate=cfg.model.learning_rate
        self.epoch=cfg.epoch
        self.warmup_ratio=cfg.model.warmup_ratio
        self.gradient_accumulation_steps=cfg.model.gradient_accumulation_steps

        #tokenizer
        self.tokenizer=tokenizer

        #model
        self.model_config=AutoConfig.from_pretrained(cfg.model.model_path)
        self.model_config.update(
            {
                "output_hidden_states": True,
                "hidden_dropout_prob": 0.1
            }
        )
        self.model=AutoModel.from_pretrained(cfg.model.model_path,config=self.model_config)

        
        self.regressor = nn.Sequential(
            nn.Linear(self.model_config.hidden_size+2, 128),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

        #dataloader
        self._train_dataloader=t_dataloader
        self._valid_dataloader=v_dataloader

        #save_hyperparameter
        self.save_hyperparameters(cfg)

    #AdamWとlinearスケジューラを基本使用
    def configure_optimizers(self):
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.weight_decay,
            },
            {
                "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
            ]
        optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.learning_rate,
            )
        num_training_steps=math.ceil(len(self._train_dataloader)/self.gradient_accumulation_steps)*self.epoch
        num_warmup_steps=num_training_steps*self.warmup_ratio

        scheduler=get_linear_schedule_with_warmup(
            optimizer,
            num_training_steps=num_training_steps,
            num_warmup_steps=num_warmup_steps,
        )

        return {'optimizer':optimizer,'lr_scheduler':scheduler}

    #推論の時も使う処理を記載
    def forward(self, x):

        input_ids=x['input_ids']
        attention_mask=x['attention_mask']

        #その他特徴量
        tweet_hour=x['tweet_hour'].reshape(-1,1)
        followers_count=x['followers_count'].reshape(-1,1)
        
        out=self.model(input_ids,attention_mask)
        #pooler output = CLSトークンのemb層を抽出
        out = out[1]
        out =  torch.cat([out, tweet_hour, followers_count], dim=1)
        #batch*1で予測
        qa_logits=self.regressor(out)

        return qa_logits

    #gradient_accumulateを加味しているため、マニュアルbackward
    def training_step(self,batch, batch_idx):
    
        opt = self.optimizers()
        sch = self.lr_schedulers()

        logits = self.forward(batch)
        labels = batch['like_count']
        loss = nn.MSELoss()(logits, labels)

        self.log("train_step_loss", loss, prog_bar=True)

        #if average 
        loss = loss / self.gradient_accumulation_steps

        #backward
        self.manual_backward(loss)

         # accumulate gradients of `n` batches
        if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
            opt.step()
            sch.step()
            opt.zero_grad()
        
        return {'logits': logits, 'labels': labels}

    def validation_step(self, batch, batch_idx):
        
        logits = self.forward(batch)
        labels = batch['like_count']

        loss = nn.MSELoss()(logits, labels)

        self.log("val_step_loss", loss, prog_bar=True)

        return {'logits': logits, 'labels': labels}

    #epoch終わりのloss計算
    def training_epoch_end(self, training_step_outputs):
        self._share_epoch_end(training_step_outputs,'train')

    def validation_epoch_end(self,val_step_outputs):
        self._share_epoch_end(val_step_outputs,'val')

    def _share_epoch_end(self, outputs, mode):
        all_logits = []
        all_labels = []
        for out in outputs:
            logits, labels = out['logits'], out['labels']
            all_logits.append(logits)
            all_labels.append(labels)
        all_logits = torch.cat(all_logits)
        all_labels = torch.cat(all_labels)
        loss = nn.MSELoss()(all_logits, all_labels)

        self.log(f'{mode}_epoch_loss', loss, prog_bar=True)

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._valid_dataloader

なお、自分は自前のGPUを持っていないため、モデルの学習はGoogle Colaboratoryで行っています。

3. 学習済BERTとDjangoを使用し、アプリ作成

最後に、学習したBERTを使用して、アプリを作成します。
フレームワークには、Djnagoを使用しました。
アプリのディレクトリ構成は以下のようにしています。

TweetLikePredict
┣ tweetlikepredictproject (プロジェクト)
┣ bertapp (アプリ)
┃  ┗ static
┃    ┗ XXXX.ckpt(学習したモデルの重み)
┃     ・
┃   ・
┃   ・
┃  ┗ bert.py (ここで推論用のpytorchモデルを定義)
┃  ┗ forms.py 
┃  ┗ views.py
┃  ┗ urls.py
┃   ・
┃    ・
┃    ・    
┗ templates (html)
   ┗ base.html
   ┗ input.html (Tweet情報の入力画面)
   ┗ result.html(予測いいね数を表示)

入力画面と、予測結果の表示画面は以下のようになっております。

①入力画面
スクリーンショット 2021-12-26 20.00.25.png

②出力画面
スクリーンショット 2021-12-26 19.57.04.png

最後にAWS EC2にてデプロイ

せっかくなので、最後にAWSのEC2にてデプロイを行います。

最初、無料枠でデプロイを試みたのですが、BERTの重みが大きいからか、
サーバーが立ち上がらず、デプロイできませんでした。

そのため、今回は有料枠を使用して、デプロイを行います。
使用したインスタンスはubuntuのt3.largeです。

なお、デプロイにあたり、こちらの記事を参考にしております。

肝心のアプリはこちら

ここまでお付き合いいただきありがとうございました。
そろそろ「で、アプリは?」となるのではないかと思ったので、いよいよ公開したいと思います。
完成系は、こちらです。


いいね予測君_AWS.gif

出力画面が出た時、「うおっ」と小さな声が出たことは秘密です。(自己矛盾)

#最後に
無事ツイートのいいね数を予測できた!・・となればよかったのですが、まだまだ課題があります。

例えば、ほとんどの予測結果が7~8いいね前後に固まってしまっていることが挙げられます。
こちらについては、バズりツイートといった外れ値に、モデルが引っ張られてしまったことが原因と考えています。
(「忘年会」という単語に絞った検索結果になりますが、半分以上のツイートがいいね数が0の一方、最大値は2万件以上と、
データの分布にかなり偏りがありました)

ただ、今回が初めてのアプリ開発となったのですが、
自分のアイデアが最終的に形になるというのはなんとも嬉しいものですね

なお、今回使用したコードは以下で公開しています。(「いいね予測君」と命名しています笑)

今後も、機械学習スキル、実装スキルを向上させてより良いアプリを作っていければと思います。
この度は最後まで見ていただき、ありがとうございました!

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?