4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

GPT-2で作るJoke BotからSlack Botまで #2 データ収集

Last updated at Posted at 2019-06-03

#イントロ
ここではモデルを学習させるためのデータ集めをしてみます。残念ながらGPT-2は日本語に対応してなく、日本語のギャグなどがたくさん(10万くらい)あるデーターセットが見つからなかったので英語のジョークのみに対応するようにしました!もしとりあえずデータが欲しければこのレポジトリーをクローンしてください。この中でjokes....npzがデーターです。Part 1はここをクリックしてください!
#データーセット
使うデーターセットは二つです。一つ目はこれです。このデーターセットのジョークはさらに3つのデーターセットのジョークが含まれています。

Reddit

一つ目はRedditというところから来たデータセットです。Redditというサイトは英語が使える方々が集うすごく人気なサイトです。このサイトは何でも話せるサイトですが話したいことによって行く場所が変わってきます。例えば機械学習の話をしたければr/MachineLearningというところで674000のメンバーと話し合うことなどができます。
子のジョークのデータセットはr/Jokesのデータから来てます。JokeはJsonに保存されていて

        "title": "My boss said to me, \"you're the worst train driver ever. How many have you derailed this year?\"",
        "body": "I said, \"I'm not sure; it's hard to keep track.\"",
        "id": "5tyytx",
        "score": 3
    }

のような形式で保存されています。英語のみに対応するジョークなのですが、Google翻訳すると

{
        "タイトル": "私の上司は私に言った、"あなたは史上最悪の列車の運転手です。今年何回脱線しましたか。
        "体": "私は言った、"私はよくわからない。追跡するのは難しいです。
        "id": "5tyytx"、
        "スコア":3
    }

という感じになります。面白さは追跡するがhard to keep trackの部分に対応しててtrackは英語で電車の線と同じ言葉なのがこのジョークの特徴です。結構こういう種類のジョークがほとんどです。ただ結構暗いジョークも多いです。例えば

        "body": "Now I have to say \"Leroy can you please paint the fence?\"",
        "id": "5tz52q",
        "score": 1,
        "title": "I hate how you cant even say black paint anymore"
    }
        "体": "今、私は言わなければならない" "リロイあなたはフェンスを塗ってください?" "
        "ID": "5tz52q"、
        "スコア":1、
        "タイトル": "私はあなたがもう黒いペンキを言うことさえできないのが嫌いだ"
    }```
のようなジョークも結構あり、fuckなどの言葉なども結構出てきます。これをきれいなジョークだけにしない理由は
1. きれいなジョークだけを選ぶやり方がわからない。
2. ほとんどのジョークはこういった類のジョークなのでそれをすべて抜いたらデータが大幅に減ってしまう。
の二つが理由です。なのできれいなジョークだけに限らないでこういった多少下品なジョークも含めることにします。
データーの収集の仕方が載っていたPythonファイルがありましたがRedditのAPIが制限されてしまったのでデーターの収集量が前より少なくなってしまったのでそれはやめました。今はRedditのデーターを大量にとった[ここ](https://files.pushshift.io/reddit/submissions/)からジョークのデーターだけとってみることを考えていますがファイルが全部大きいこともあり、まだできてないです。
これがデーターセットのほとんどのジョークを担います。
少しは面白くないと困るので、スコアが1はないと学習するデーターに入れないようにしました。Redditでは面白ければいいねとほかのユーザーができるようになっているのでこのスコアはいいねとした人数を表します。
# stupidstuff.json と wocka.json
残り二つのデーターセットはstupidstuff.orgとwocka.comからweb scrapingでとったジョークです。Jsonフォーマットはそれぞれ
```{
        "category": "Blonde Jokes",
        "body": "A blonde is walking down the street with her blouse open, exposing one of her breasts. A nearby policeman approaches her and remarks, \"Ma'am, are you aware that I could cite you for indecent exposure?\" \"Why, officer?\" asks the blonde. \"Because your blouse is open and your breast is exposed.\" \"Oh my goodness,\" exclaims the blonde, \"I must have left my baby on the bus!\"",
        "id": 14,
        "rating": 3.5
    }```
と
```{
        "title": "Infants vs Adults",
        "body": "Do infants enjoy infancy as much as adults enjoy adultery?",
        "category": "One Liners",
        "id": 17
    }```
という感じです。両方のデータセットはRedditの奴に比べても結構下品ですが一応含めてみることにしました。最近は抜いてみることとかも考えています。
ただ、見ての通りwocka.jsonの方はスコアがないのでとりあえず全部含めることにしました。stupidstuff.jsonではスコアが小数点でしたので、実際にサイトに言って確認したところ5つ橋中の評価だということがわかりました。よってこれはたぶん詰まんないジョークも含みますがスコアが一以上だと含むようにしました。これはユーザーが指定するように切り替えています。
# shortjokes.csv
OpenAIの実際にGPT-2を作った方が子のデーターセットを使用してGPT-2をジョークに対応させているとTwitterで言っていたので含めることにしました。このデータセットはcsv fileで

2 Telling my daughter garlic is good for you. Good immune system and keeps pests away.Ticks, mosquitos, vampires... men.```
という形式で数字がIDでJokeがテキストです。
今見て分かったのですがこれもRedditからデータとっていることがわかりました。なので短いジョークの身を選ぶことによってデータが結構減らしていることがわかりました。そこからデータとったら結構データが増やせることがわかりました。試してみたらもともと20MBだったものが70MBくらいになりました。
しかし、とりあえず今のところはこれでやってみます。
これで集められたデータは大体全部で100MBくらいです。shortjokes.csvとredditのjsonでかさなる部分があるのでそれをなくす必要があります。また、何千文字も続くジョークを書く方々などもおり、制限する必要が出てくると思います。これらをすべてプログラムにして実際に動かしてみると結果的に大体30MBから50MBくらいになります。
正直なところ1GBくらいのデータが欲しかったのでこれは少し残念です。ちょっと自分で実験してここここのデータも入れてみましたが残念ながらデーターは大体10MBくらいしか上がりませんでした。

データ統合

ここでいうGPT-2を使ってJokebotを作るというのはGPT-2のもうトレーニングが終了したモデルをさらにJokeでトレーニングすることです。これによって起こることは文法や文章の作り方がわかった段階でJokeを学習することができる。つまり、基礎を知った段階で応用を勉強するみたいなものです。英語でこれはFine tuningといいます。

ただ、モデルがどうやってできているかもよくわからないのにfine tuneするのは少し怖いですよね。なのでここではもうした人を探してみることにしました。運よく作っていた人がいました。ここをクリックしたらレポジトリーに行けるはずです。
クローンの方法はコマンドラインで
git clone https://github.com/nshepperd/gpt-2.git
とすることです。
これでfinetuneする方法は以下の通りです。

  1. もう学習済みのGPT-2モデルをダウンロードする。これはクローンしたレポジトリーまで行って
python download_model.py 117M

と入れたらできます。もしいいパソコンを持っていたら

python download_model.py 345M

とより大きいモデルをダウンロードしてもいいかもしれません。今のところ公開されているモデルはこの二つだけですが一個は500MBくらいでもう一個は1GBくらいです。なので実際のところこれ以上大きなモデルが公開されてもRAMのこともあってほとんどの人が使えないっというのが現実だと思います。実際に345Mの方はColabでfine tuneしようとしたらメモリーエラーになりました。
2. テキストファイルを用意する。
このテキストファイルは学習用データです。一個の文章が終わるたびに<|endoftext|>という言葉が付いていて始まりには改行のための\nがあるのが基本です。説明のためにこれはjokes.txtとします。
3. エンコードする。
この前にクローンしたレポジトリーのsrcというフォルダーに入っているファイルをすべて一個上にもっていってください。つまり、download_model.pyとencode.pyが同じフォルダーにあれば大丈夫です。ここで
python encode.py jokes.txt jokes.txt.npzと入れます。そうするとjokes.txt.npzというファイルが出てきます。
4.学習/fine tuneを始める。
基本的な学習は

train.py --dataset /path/to/encoded.npz

とすることでできます。
では、とりあえず、テキストファイルを作ってみましょう。
#テキストファイルの作り方
ここで話す内容はこのレポジトリーのmakeDatabase.pyの部分に対応します。このレポジトリーの中にもう必要なオリジナルのデータも入れといたので別々にダウンロードしなくても大丈夫です。
まず最初にライブラリーを入れます!

import json
from tqdm import tqdm
import re
import argparse

json はjson fileを使っているから入れました。tqdmは本当に存在してくれてありがたいライブラリーです。することは作業の進行状況などを見せることです。あとで使ったりするのでその時に使い方を説明します。reというのはテキストの中のパターンを使ってテキストを変形できるライブラリーです。ちょっと使います。もしもっと深い説明がよければコメントで言ってください。
最後のargparseは最近ようやく使い始めたのですがユーザーからの入力に対応するときはほとんど使うライブラリーです。fireというライブラリーも結構いいです。もし興味あったら英語ですがこのレポジトリーを見てみてください。

parser = argparse.ArgumentParser(
description='Specifying data properties.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--joke_length', metavar='len', type=int, default=800, help='Length of joke. -1 to include all jokes')
parser.add_argument('--reddit_rating', metavar='re_rat', type=int, default=2, help='Minimum upvotes for reddit jokes')
parser.add_argument('--stupidstuff_rating', metavar='stu_rat', type=float, default=3.0, help='Minimum stars out of 5 for joke to be included')
parser.add_argument('--include_wocka', metavar='inc_wocka', type=bool, default=False, help='Include wocka.json')

ここではargparseというライブラリーの実用例を見せます。入力は4つです。一つ目はjoke_length、つまりjokeの最大の長さです。これは何も入れなければ800になるようにしました。reddit_ratingはredditにおける最低のupvoteの数です。

これはdefaultは2にしました。stupidstuff_ratingはstupidstuff.jsonの5星中の最低評価です。これは3.0にしました。

include_wockaはwocka.jsonを含むか含まないかのパラメターです。defaultでは含まないようにしました。
この値などを変更したい場合などは

python makeDatabase.py --joke_lenght = 400

などとすると変えられるようになっています。実際に引数にアクセスしたい場合は

args = parser.parse_args()

とするとargs.joke_lengthとするとjoke_lengthにアクセスできるようになったりします。

args = parser.parse_args()
joke_length = args.joke_length if args.joke_length != -1 else 10**7
#assume the max joke lenght cannot exceed 10**7 characters
text_file_name = f"jokes_{str(joke_length)}_{str(args.reddit_rating)}_{str(args.stupidstuff_rating)}_{str(0 if args.include_wocka else 1)}.txt"
texts = []

こうしてもしjoke_lengthが-1ならほぼ無限までの長さのjokeを一つのjokeにするようにしてついでにテキストファイルの名前を設定しました。なのでデフォルトでは
jokes_800_2_3.0_1.txtという名前のファイルになります。言葉で明記すれば
jokes_{ジョークの最大の長さ}{Redditの最小のupvote数}{stupidstuff.jsonのスコア}_{wocka.jsonを入れたら0入れなければ1}.txtです!ちょっと1と0を反転した方がいいかもしれませんね。ちょっと考えてみます!

今度はデータを一つのtextsという配列に入れました。直接テキストファイルなどに書き込まない理由は同じジョークなどが使われていたら書き込まれたものと比較する必要があってちょっとそれは面倒でたぶんもっと遅い作業だと思ったからです。

with open("./reddit_jokes.json") as f:
	jokes = json.load(f)
	for joke in tqdm(jokes):
		body = format_data(joke["body"])
		title = format_data(joke["title"])
		score = float(joke["score"])
		if score >= args.reddit_rating and len(body) > 0 and len(title) > 0 and len(body) < joke_length:
			text = "\n\n"+fix_encoding(title + " " + body + " <|endoftext|>" )
			texts.append(text)


with open("./wocka.json") as f:
	jokes = json.load(f)
	if args.include_wocka:
		for joke in tqdm(jokes):
			body = format_data(joke["body"])
			category = format_data(joke["category"])
			title = format_data(joke["title"])
			if len(body) > 0 and len(title) > 0 and len(body) < joke_length:
				text = "\n\n"+fix_encoding(title + " " + body + " <|endoftext|>")
				texts.append(text)

with open("./stupidstuff.json") as f:
	jokes = json.load(f)
	for joke in tqdm(jokes):
		body = format_data(joke["body"])
		category = format_data(joke["category"])
		score = float(joke["rating"])
		if score >= args.stupidstuff_rating and len(body) > 0 and len(title) > 0 and len(body) < joke_length:
			text = "\n\n"+fix_encoding(body + " <|endoftext|>")
			texts.append(text)
import pandas as pd
data = pd.read_csv("jokes.csv")
indices = data.shape[0]
for i in tqdm(range(indices)):
	body = data.iloc[i]["Joke"]
	if len(body) > 0 and len(body) < joke_length:
		text = "\n\n"+fix_encoding(body + " <|endoftext|>")
		texts.append(text)
data = pd.read_csv("jokes_score_name_clean.csv")
indices = data.shape[0]
for i in tqdm(range(indices)):
	title = data.iloc[i]["q"]
	body = data.iloc[i]["a"]
	if len(body) > 0 and len(body) < joke_length:
		text = "\n\n"+fix_encoding(title + " " + body + " <|endoftext|>")
		texts.append(text)
data = pd.read_csv("qajokes1.1.2.csv")
indices = data.shape[0]
for i in tqdm(range(indices)):
	title = data.iloc[i]["Question"]
	body = data.iloc[i]["Answer"]
	if len(body) > 0 and len(body) < joke_length:
		text = "\n\n"+fix_encoding(title + " " + body + " <|endoftext|>")
		texts.append(text)

ここで指摘することが2点あります。もしほかにわからないことなどがあれば教えてください。

  1. textはすべて

    text = "\n\n"+fix_encoding(title + " " + body + " <|endoftext|>")

のように書かれています。これはまず改行して(\n\n)テキストが終わったら<|endoftext|>とテキストが終わったとモデルに知らせるようになっています。もしこれをしなければ学習を始めると改行などがなくなってしまうので、出力がすべて一行になってしまいます。また、<|endoftext|>とテキストの終わりを知らせるものがなければ、gpt-2はテキストすべてが一つのジョーク、つまり10000個くらい前の言葉も今の言葉に少しくらいは影響があると勘違いしてしまいます。なので<|endoftext|>とテキストの最後を知らせる必要があります。
2. format_dataとfix_encodingが2点目です。ここのコードは以下の通りです。

def fix_encoding(s):
   return re.sub('[\xc2-\xf4][\x80-\xbf]+',lambda m: m.group(0).encode('latin1').decode('utf8'),s)


 def format_data(data):
    data = data.replace('\n', " ").replace("\r", " ").replace("\t", " ").replace("    ", " ").replace("  ", " ").rstrip().strip().replace("  ", " ")
    data = data.replace(u"\u201c", '"').replace(u"\u201d", '"').replace(u"\u2019", "'").replace(u"\u2026", "...")
    data = fix_encoding(data)
    return data

ここのformat_dataの最初の二行はこのレポジトリーのファイルから引用しました。これをした理由は主に同じジョークが同じようにフォーマットされるようにするためです。このコードは実際にshort_jokes.csvとそれを少し変えてこっちのレポジトリーにあるjokes.csvのフォーマットに使われたコードです。なので同じジョークで違うように加工したら違うジョークと判別されてしまうのでこのようにエンコーディングしました。
fix_encodingはここの方のコードから学んだことなのですがgpt-2にテキストデータを入れてエラーにならないようにした工夫です。結構いろんな方のコードをありがたく使ってます。ただ、そのあとfix_encodingをまたしている理由は大してないですが入れといても害はないと思いますw。少しバグが怖かったのが入れている正直な理由です。

テキストファイルに書く

texts = set(texts)
for text in tqdm(texts):
	try:
		with open(text_file_name, "a+") as f:
			f.write(text)
	except:
		pass

最後にpythonでsetに変えることで同じジョークをすべて排除します。そして、とうとう書き込み始めます。
ここにある"a+"というのはもしファイルがなければ作ってもしあれば下に書き込めという指令です。tryとexceptの間にこれをすることは結構どう対処してもエンコーディングのエラーがいっぱい出てきて直し方がわからないからです。よくあるエラーは

    cp932' codec can't encode character '\u2014' in position 4: illegal multibyte sequence
    'cp932' codec can't encode character '\ufeff' in position 101: illegal multibyte sequence
    'cp932' codec can't encode character '\xf1' in position 52: illegal multibyte sequence
    'cp932' codec can't encode character '\u026f' in position 4: illegal multibyte sequence

などですが正直解決策が思いつかなくてついでにすべて違う単語でエラーなのでちょっと勝ちようのないもぐらたたきのような気がして試合放棄しましたw。

だけど結構調べてもわからなかったのでたぶん書き込めないようなデータもあるのかもしれません。だけどデーターを増やしたいので直してみようと思います。もしそれができたらここにどうやって直したかなども書いてみます。
次回はコードの編集に入っていきたいと思います。ここに投稿しました。

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?