Perspective APIは、有害な文章を分類するためのAPIですが、そこで用いられるような教師データがkaggleのコンペ jigsaw toxic comment classificationとして公開されました。ここでは、word2vecを用いて試したいと思います。


Untitled drawing (4).jpg

  1. wikidumpなどを用いてword2vecのEmbeddingを作成(ここではgloveのpretrain modelを用いる。)
  2. 訓練データの文を単語に分割。
  3. 文ごとの単語のベクトル表現の平均ベクトルを特徴量とする。
  4. GBRTへ渡す。
  5. predict_probaで確率を算出する。


  1. kaggleからデータをダウンロード。
  2. gloveのpretrain modelを用意。
  3. anacondaのインストール。


import pandas as pd
import re
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.ensemble import GradientBoostingClassifier

train = pd.read_csv("train.csv")
test = pd.read_csv("test.csv")

target_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

X = train[['comment_text']].as_matrix().tolist()
X_test = test[['comment_text']].as_matrix().tolist()

def fix_data(X):
    out = []
    for i, sentence in enumerate(X):
        tmp = str(sentence[0]).lower()
        regex = re.compile('[^a-zA-Z ]')
        tmp = re.sub(regex, "", tmp)
        tmp = tmp.replace('  ', ' ')
        out.append(tmp.split(' '))
    return out

X = fix_data(X)
X_test = fix_data(X_test)

test_ids = test[['id']].as_matrix().ravel()

with open("glove.6B.50d.txt", "rb") as lines:
    w2v = {line.split()[0]: np.array(line.split()[1:]).astype(float)
           for line in lines}

class MeanEmbeddingVectorizer(object):
    def __init__(self, word2vec):
        self.word2vec = word2vec
        self.dim = 50

    def fit(self, X, y):
        return self

    def transform(self, X):
        return np.array([
            np.mean([self.word2vec[bytes(w, 'ascii')] for w in words if bytes(w, 'ascii') in self.word2vec]
                    or [np.zeros(self.dim)], axis=0)
            for words in X

clf = Pipeline([
    ("vectorizer", MeanEmbeddingVectorizer(w2v)),
    ("gbrt", GradientBoostingClassifier(n_estimators=1000, verbose=True))])

out = {"id":test_ids}
for tc in target_columns:
    y = train[[tc]].as_matrix().ravel(), y)
    out[tc] = clf.predict_proba(X_test)[:, clf.classes_.tolist().index(1)]

pd.DataFrame(out).to_csv("submission2.csv", index=False)


Screenshot from 2017-12-29 07-17-32.png



Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.