目次
- やりたかったこと
- Qiita投稿を取得する
- Tweet2Vecを使う
- GPUインスタンスを使う
- タグ予測結果
- 類似投稿を出してみる
- 考察と課題
やりたかったこと
- 短い日本語文書(ツイートなど)を分類したい
- ニューラルネットワーク使いたい
- 形態素解析せずにやりたい
SNSの投稿などを処理する場合、誤字、脱字、スラング、新語、絵文字、顔文字、外国語、専門用語、表記ゆれなどが多く含まれるため、形態素解析器を用いたアプローチは不利に思われる。近年のNLP論文を読んでいると単語レベルではなく文字レベルで学習させる方向にいっているようなので、その流れに乗ってみる。日本語は一文字あたりの情報量が大きいので英語より有利だと思う。
文書が長すぎず、形態素解析しにくそうで、ある程度トピックにまとまりがある題材として、Qiitaの投稿をタイトルのみで分類してみることにした。Qiitaの本文はmarkdownまたはHTMLで、扱いが難しそうなので今回は使わない。
Qiita投稿を取得する
Qiita APIを使用。
本文やその他のメタ情報も今後使う可能性があるので、APIが返すJSON全体を、PythonでPeeweeを使いPostgreSQLに保存した。
from peewee import *
from playhouse.postgres_ext import PostgresqlExtDatabase, BinaryJSONField
db = PostgresqlExtDatabase(
'mytable',
user='myuser',
password='mypassword',
register_hstore=False,
autocommit=True,
autorollback=True
)
db.connect()
# モデル定義
class BaseModel(Model):
"""A base model that will use our Postgresql database"""
class Meta:
database = db
class QiitaItem(BaseModel):
item_id = CharField(unique=True)
created_at = DateTimeField(index=True)
raw = BinaryJSONField()
class Meta:
db_table = 'qiita_items'
# テーブル作成
db.create_tables([QiitaItem])
Qiita APIの呼び出し
トークンを使わないと呼び出し回数制限がきついので注意。
トークンを使っても毎時1,000リクエストが上限なので注意。
import json, requests
url = 'http://qiita.com/api/v2/items'
headers = {'Authorization': 'Bearer myauthorizationtoken1234567890'}
# とりあえず最新10,000投稿を取得する例
# それ以上の投稿は、web検索と同じqueryフィールドがparamsに使えるので、それを駆使して頑張る。
for i in range(100):
resp = requests.get(url, params={'per_page': 100, 'page': 1+i}, headers=headers)
for item in resp.json():
# Null文字があるとJSONBへの保存が失敗するので除去する
item['body'] = item['body'].strip('\x00')
item['rendered_body'] = item['rendered_body'].strip('\x00')
try:
QiitaItem.create_or_get(item_id=item['id'], created_at=item['created_at'], raw=item)
except DataError:
print(False)
Tweet2Vecを使う
まずは長さがまちまちであるQiitaのタイトルを、一定次元数のベクトル表現に変換したい。ベクトル空間モデルにしてしまえば様々な機械学習手法が適用できるようになる。
当初gensimのDoc2Vecを試したものの、イマイチ精度が出なかった。
色々論文を探した結果、Tweet2VecというRNN(GRU)を使った方法を発見。しかもGitHubにコードが公開されている。元論文では200万ツイートを2000種類のハッシュタグにクラス分類し、単語ベースの手法よりも高い精度を出したとのこと。
Qiitaの投稿のほとんどにはタグが複数付いており、文章の長さもツイートと同程度のため流用できると考えた。
Tweet2VecはTheano+Lasagneで実装されている(Python2.7)
データファイルの作成
100以上の投稿が存在する393種類のタグを使用。これらのタグが付いた投稿約63,000件のうち、5%をテスト、5%をCV用とし、残りの90%を学習用にした。
英語や中国語の投稿も混ざっているが、特に除外しない。元論文では英語のツイートのみを使用しており、大文字を小文字に変換する前処理が入っていたが、これも行わない。
訓練データ中に2281種類の文字が含まれており、BOWを使う場合より入力次元数は大幅に少ないと思われる。
from collections import Counter
from random import shuffle
import io
cnt = Counter()
for item in QiitaItem.select():
for tag in item.raw['tags']:
cnt[tag['name']] += 1
common_tags = [name for name, c in cnt.most_common(393)]
samples_all = []
for item in QiitaItem.select():
tags = [tag['name'] for tag in item.raw['tags']]
intersection = list(set(common_tags) & set(tags))
if len(intersection) > 1:
samples_all.append((item, intersection))
shuffle(samples_all)
n_all = len(samples_all)
n_test = n_cv = int(n_all / 20)
n_train = n_all - n_test - n_cv
samples_train = samples_all[:n_train]
samples_test = samples_all[n_train:(n_train+n_test)]
samples_cv = samples_all[(n_train+n_test):]
with io.open('misc/qiita_train.txt', 'w') as f:
for item, tags in samples_train:
for tag in tags:
f.write(tag + '\t' + item.raw['title'] + '\n')
with io.open('misc/qiita_test.txt', 'w') as f:
for item, tags in samples_test:
f.write(','.join(tags) + '\t' + item.raw['title'] + '\n')
with io.open('misc/qiita_cv.txt', 'w') as f:
for item, tags in samples_cv:
f.write(','.join(tags) + '\t' + item.raw['title'] + '\n')
実行
3種類のシェルスクリプトが提供されているので、中のファイルパスを書き換えて実行する。
学習させるとき
./tweet2vec_trainer.sh
精度を見たいとき
./tweet2vec_tester.sh
(正解のタグが不明なものに対して)タグを付けたい時
./tweet2vec_encoder.sh
注意点として、
- モデルの訓練に使ったタグの種類が少なすぎる場合、tweet2vec_encoder.sh実行時にエラーが出る。
- Qiitaタイトルの一部にタブや改行コードが含まれているらしく、除去するかTweet2Vecのコードを微修正する必要がある。
GPUインスタンスを使う
試しに訓練データファイル8,000行で学習させたところ、手元のMacBookで約2.7時間かかった。最終的な訓練データファイルは14万行ほどになっており、これをローカルマシンで学習させると完了まで48時間とかいうレベルで時間がかかりそうなため、AWSでGPUインスタンスを立てることにした。
- g2.2xlarge
- Ubuntu 14.04 64bit
- 節約のため米国東部リージョンのスポットインスタンスを使用
- ハイパーパラメータはTweet2Vecの初期設定からいじらず
この記事を参考にAWSのGPUインスタンスにTheanoをインストール。matplotlibがpipで入らず焦ったが、sudo apt-get install python-matplotlib
で入った。
途中で接続が切れると困るのでscreenして実行。その後狙ったかのようにターミナルがクラッシュしたがこのおかげで無事だった。
screen -S "qiita"
./tweet2vec_trainer.sh
約3.5時間、12 epochで終了。推定10倍以上のスピードアップなので課金した甲斐があった。
タグ予測結果
タグ予測の精度
テスト用データで精度72.10%、
訓練中一切参照されていないCVデータでも70.49%の精度を得た。
./tweet2vec_tester.sh (qiita_test.txt)
Precision @ 1 = 0.721040939384
Recall @ 10 = 0.777895906062
Mean rank = 4.29197080292
./tweet2vec_tester.sh (qiita_cv.txt)
Precision @ 1 = 0.704855601396
Recall @ 10 = 0.779064847138
Mean rank = 4.05744208188
元の論文でのTweetのハッシュタグ予測精度が24%程度だったので、題材が大きく異なるとはいえかなりの精度だと思われる。
以下、CV用データに対してタグを予測させてみた例を紹介する。
うまく予測できている例
# | 投稿タイトル | 実際のタグ | 予測されたタグ(TOP 10) |
---|---|---|---|
1 | Java プログラムを Maven から実行する方法 | Java,Maven | Maven,Java,java8,Eclipse,gradle,Android,Tomcat,Groovy,JavaEE,JUnit |
2 | JavaScriptで非同期無限ループ | JavaScript,Node.js | JavaScript,HTML,jQuery,Node.js,CSS,HTML5,CoffeeScript,js,Ajax,es6 |
3 | Macを買ってからRailsを開発するまでにやること | Mac,Rails | Ruby,Rails,Mac,Rails4,rbenv,RubyOnRails,MacOSX,Zsh,homebrew,Gem |
4 | 【Unity】ゲーム会社でスマホ向けゲームを開発して得た知識 UI編 | .NET,Unity3D,C#,Unity | Unity,C#,Unity3D,.NET,Android,iOS,LINQ,VisualStudio,android開発,Java |
5 | R言語 - ハイパフォーマンスコンピューティング | R,statistics | R,statistics,数学,数値計算,データ分析,Ruby,統計学,自然言語処理,Python,NLP R言語 |
6 | さくらVPSサーバーの初期設定とLAMP環境構築 | PHP,MacOSX,さくらVPS | さくらVPS,vps,Apache,CentOS,PHP,MySQL,Linux,CentOS6.x,WordPress,postfix |
7 | アニメーションでフレームアウトさせる(StoryBoard) | Xcode,Objective-C | iOS,Swift,Storyboard,Xcode,Objective-C,Xcode6,Android,UI,iOS8,iPhone |
8 | 出会い系アプリを作ってわかったAppleのリジェクト地獄へようこそ | Xcode,iOS | iOS,Swift,Xcode,Objective-C,Android,iPhone,CocoaPods,JavaScript,Mac,AdventCalendar |
9 | 半透明な要素内に透けない要素を表示 | HTML,CSS | CSS,HTML,HTML5,CSS3,JavaScript,jQuery,bootstrap,Android,js,Java |
10 | 用docker和golang做一个中文分词应用 | golang,docker | Go,golang,docker,Ruby,Slack,vagrant,Rails,GitHub,OSX,Erlang |
- 基本的に、タグ名称がそのままタイトル中に登場する場合は正しく上位候補に挙げることができている。
- 例1と2のように、JavaとJavaScriptを全く混同していない。単純な字面の類似度に惑わされず関連する概念を分類できている。当たり前のことのように見えるが、Doc2Vecを文字単位で適用した時はこれができていなかった。
- いわゆる「関連用語」からタグを予測できている。例6ではおそらく「LAMP」から「Linux」「Apache」「MySQL」「PHP」を予測しており、例9ではおそらく「要素」という用語からCSS関係の投稿であることを予測できている。
- 例10は中国語だが、惑わされることなくタグを当てている。
うまく予測できていない例
# | 投稿タイトル | 実際のタグ | 予測されたタグ(TOP 10) |
---|---|---|---|
1 | 全くの素人がISUCON5本戦に参加しました | golang,MySQL,Go,nginx | iOS,Unity,C#,Objective-C,JavaScript,Swift,Ruby,.NET,IoT,JSON |
2 | 食べられるキノコを見分ける | MachineLearning,Python,matplotlib | Linux,iOS,ShellScript,Ruby,Python,Objective-C,CentOS,PHP,Bash,Swift |
3 | gnome3で日本語入力の切替を右Altにする(AXキーボード風) | CentOS,Linux | JavaScript,Java,OSX,homebrew,Mac,api,HTML,Node.js,MacOSX,人工知能 |
4 | 静的ファイルがキャッシュされる(ブラウザキャッシュじゃない場合) | VirtualBox,Apache | JavaScript,IoT,CSS,Chrome,firefox,MacOSX,Windows,HTML5,jQuery,Arduino |
5 | 目的のハッシュタグを含むツイートをした人のユーザーネーム(@の後ろの部分)を取得する | Twitter,TwitterAPI,Gem,Ruby | Ruby,Rails,AWS,JavaScript,Python,Go,Java,jq,golang,PHP |
6 | viewでForm: could not find implicit value for parameter messages: play.api.i18n.Messages出るときの対処 | Scala,PlayFramework | Mac,MacOSX,OSX,Xcode,Android,Linux,Ruby,Ubuntu,Java,Windows |
- 当然ではあるが、タイトル中にヒントが少なすぎる、あるいは曖昧すぎる場合は予測がうまくいっていない。
- 例3の「gnome」、例5の「ハッシュタグ」「ツイート」などはヒントとなる用語であるが、おそらく訓練データが十分でないため何に関する用語なのか分からなかったと思われる。
類似投稿を出してみる
Tweet2Vecによって各投稿タイトルのベクトル表現が得られるので、コサイン距離を元に類似投稿を出力させてみる。
encode_char.pyを少し改造し、必用なファイルを出力させるようにする。
print("Encoding...")
out_data = []
out_pred = []
out_emb = []
numbatches = len(Xt)/N_BATCH + 1
for i in range(numbatches):
xr = Xt[N_BATCH*i:N_BATCH*(i+1)]
x, x_m = batch.prepare_data(xr, chardict, n_chars=n_char)
p = predict(x,x_m)
e = encode(x,x_m)
ranks = np.argsort(p)[:,::-1]
for idx, item in enumerate(xr):
out_data.append(item)
out_pred.append(' '.join([inverse_labeldict[r] for r in ranks[idx,:5]]))
out_emb.append(e[idx,:])
# Save
print("Saving...")
with open('%s/data.pkl'%save_path,'w') as f:
pkl.dump(out_data,f)
with io.open('%s/predicted_tags.txt'%save_path,'w') as f:
for item in out_pred:
f.write(item + '\n')
with open('%s/embeddings.npy'%save_path,'w') as f:
np.save(f,np.asarray(out_emb))
学習やテストに使わなかった投稿も含め、約13万件のQiita投稿を使用。
with io.open('../misc/qiita_all_titles.txt', 'w') as f:
for item in QiitaItem.select():
f.write(item.raw['title'] + '\n')
結果をqiita_result_allというディレクトリに出力させるように書き換えて実行。
./tweet2vec_encoder.sh
たかをくくってローカルマシンで実行したところ、約3時間かかった。これもGPUインスタンスを使えばよかった。
類似投稿を表示させるPythonコードはこんな感じ。
import cPickle as pkl
import numpy as np
from scipy import spatial
import random
with io.open('qiita_result_all/data.pkl', 'rb') as f:
titles = pkl.load(f)
with io.open('qiita_result_all/embeddings.npy','r') as f:
embeddings = np.load(f)
n = len(titles)
def most_similar(idx):
sims = [(i, spatial.distance.cosine(embeddings[idx],embeddings[i])) for i in range(n) if i != idx]
sorted_sims = sorted(sims, key=lambda sim: sim[1])
print titles[idx]
for sim in sorted_sims[:5]:
print "%.3f" % (1 - sim[1]), titles[sim[0]]
実行結果
左列の数値が類似度(1が最大、-1が最少)、右列がタイトル。
>>> most_similar(random.randint(0,n-1))
複数のGoogleMapとjqueryのwrap()を使ったときの謎現象
0.678 複数のGoogleカレンダーをまとめる
0.619 【GM】GoogleMapで検索した住所等の緯度経度を知りたい
0.601 さくっとGoogleMapsを表示する(API v3版)
0.596 スマートフォンブラウザで jQuery Sizzleと filterAPIを合わせて使った時の速度計測をしました
0.593 住所からGoogle Mapをサイトに表示する方法
>>> most_similar(random.randint(0,n-1))
scipyいれたいだけなのにこけたメモ。Ubuntu,Python3.
0.718 SciPyとmatplotlibのインストール(Python)
0.666 SciPy+matplotlibで3D散布図を作成(Python)
0.631 scipyはpython 2.7.8だとpip installでつまずく
0.622 【備忘録】future文 ~Python~
0.610 scipyとか使ってみる
>>> most_similar(random.randint(0,n-1))
複数行のUILabelのテキストを左上に寄せる
0.782 IBでUILabelのテキストを上寄せにする
0.699 NGUIのUILabelを一文字ずつ書き出す
0.624 SwiftでUILabelのテキストを中央寄せにする
0.624 IBからUILabelのtextに改行を入れる
0.624 2つのUITableViewのスクロールを同期させる
>>> most_similar(random.randint(0,n-1))
[Unity] プロジェクトを git や svn で管理するまえにやっておくこと
0.810 UnityでGit管理するときの設定項目
0.800 [Unity] git で管理する際の無視ファイル(.gitignore)設定
0.758 [Unity]ネットワークを通してメッセージをやり取りする簡単なサンプル
0.751 【Unity 基本操作】 マップ・ダンジョン制作時の便利なVキーやスナップ
0.746 [Unity] テキストファイルをスクリプトでロードして中身を表示する方法
>>> most_similar(random.randint(0,n-1))
Javaで回転行列を使った円の描画
0.911 java 回転行列を使用した線描画
0.898 Javaで整数とバイト配列の相互変換
0.897 Javaでフィボナッチ数列を表示
0.896 Java 2D 回転行列を使った円の描画
0.873 Javaの文字列結合効率化を確認
>>> most_similar(random.randint(0,n-1))
WWDC14で流れていた曲一覧
0.830 WWDC15に参加した話
0.809 WWDC 2014のメモ
0.797 WWDC2015をまとめてみた
0.772 WWDC2015で自分が聞いた質問まとめ
0.744 WWDC 2016の落選メールが当選メールだった話
- アルファベットの小文字・大文字を別の文字として学習させたにもかかわらず、「scipy」と「SciPy」、「Java」と「java」、「jQuery」と「jquery」、「Git」と「git」が同一概念であることを認識しているっぽい。
- 「GoogleMap」「GoogleMaps」「Google Map」も同じ概念であると認識できているっぽい。
- しかし「Java」と「JavaScript」、「Go」と「Google」は混同しない。
- 「WWDC」という、学習用の393種類のタグには含まれていないトピックについても、字面の類似度からか、近い投稿として取り出せている。
考察と課題
- 形態素解析せず、ユーザー辞書も作らず、前処理もほとんど行わずに学習させただけでこれだけの精度が出るのはものすごくありがたい。
- Qiita本文のような、さらに自然言語処理がしにくそうな文書にも適用できないか。
- 今回はTweet2Vecをそのまま使わせてもらっただけなので、中身をきっちり理解してより良いモデルを作りたい。
- 実プロダクトとして使うのであればハイパーパラメータもチューニングしたい。今回無視した各種メタ情報も組み合わせたい。