Help us understand the problem. What is going on with this article?

Webコンテンツ抽出のvision-based CNN (自作)

More than 1 year has passed since last update.

Webコンテンツ抽出のvision-based手法とは、Webページのスクリーンショットを解析し、コンテンツ抽出の特徴量として使う手法です。TextMapsというオープンソースの手法もありますが、今回はモデル自体を自作します。

事前準備

Webページのスクリーンショットとdomツリーを取得する

この部分は本質ではないので詳細は省きますが、以下の仕様を満たすCLIツールを作成してください。

"urlを指定すると、そのurlの「スクリーンショット」と「スクリーンショット画像内のdom要素の位置を保存したdomツリー」を出力する。"

TextMapsというプロジェクト内には、download_page.jsという名前でphantomjsのスクリプトが置かれています。

訓練データの収集とアノテーション

前述のツールを使ってWeb記事を収集します。取得されたスクリーンショットとdomツリーは、pascal voc形式に変換し、labelImgで読み込むことでアノテーションします。

以下を参考にしてください。
https://qiita.com/sugiyamath/items/968463b26c0b9b0d0c40

アノテーション時のルールは以下を採用します。
1. アノテーションの候補要素は、dom.json内のpositionをもつすべての要素。
2. ニュースやブログの記事を対象とする。
3. 記事内のコンテンツ(本文)全体を含み、かつ余分な要素をできるだけ含まない候補位置を1つだけ選び、ラベル付けする。

ルール3により、アノテーション作業が簡素化され、効率的に作業ができます。

データの切り出し

アノテーション済みデータから、以下のルールでデータを切り出します。

  1. 幅と高さが500以下の要素を除外。
  2. アノテーションで"difficult"ラベルがつけられていたら画像自体を除外。
  3. ラベル付けした1つの要素以外はFalse, ラベル付けした要素はTrueとしてラベルを定義。
  4. dom要素のポジションとラベルを辞書として保存。

以下は辞書の形式です。

targets = {
    "ps":[ターゲット画像の切り出されたポジション一覧], 
    "labels": [ターゲット画像の切り出されたラベル一覧]
}

assert len(targets['ps']) == len(targets['labels'])

この辞書をpkl形式で保存しておきます。

2つのモジュールを作成

モジュール1: bindetector.py

import numpy as np
from skimage.transform import resize
from skimage import io

def load_image(image_file, p, size=(905,905)):
    im = io.imread(image_file)
    alpha = np.zeros([im.shape[0], im.shape[1]])
    alpha[p[1]:p[3], p[0]:p[2]] = 255.
    im = np.dstack((im, alpha))
    im = resize(im, size)/255.0
    return im

モジュール2: dataprocessor.py

from bs4 import BeautifulSoup
import bindetector
import numpy as np
import pickle
import os
import random
from sklearn.utils import shuffle


def difficult_check(soup):
    for o in soup.find_all("object"):
        if int(o.find('difficult').text) == 1:
            return True
    return False


def extract_positions(soup, min_width=500, min_height=500):
    targets = {'ps':[], 'labels':[]}
    for o in soup.find_all("object"):
        xmax = int(o.find('xmax').text)
        xmin = int(o.find('xmin').text)
        ymax = int(o.find('ymax').text)
        ymin = int(o.find('ymin').text)
        width = xmax - xmin
        height = ymax - ymin
        if width < min_width or height < min_height:
            continue
        else:
            targets['ps'].append([xmin, ymin, xmax, ymax])
            targets['labels'].append(o.find('name').text == "content")
    return targets


def data_preparation(keyname, rootpath):
    path = os.path.join(rootpath,keyname)
    files = [f.split(".")[0] for f in os.listdir(path) if f.endswith("xml")]
    return path, files


def get_and_check_targets(soup):
    if difficult_check(soup):
        return False, None
    targets = extract_positions(soup)
    if sum(targets['labels']) != 1:
        return False, None
    return (True, targets)


def transform_img(index, ps, size=(905, 905), data_path="data", img_format="jpeg", batch_path="batch"):
    X = []
    keyname, fileid = index.split("_")[:2]
    pic = "{}.{}".format(fileid.split(".")[0], img_format)
    path = os.path.join(data_path, keyname)
    for p in ps:
        X.append(bindetector.load_image(os.path.join(path, pic), p))
    return X


def sampling(targets, sample_size=4, index=False, batch_path="batch"):
    assert len(targets['ps']) == len(targets['labels'])
    if index is not False:
        npy_file = os.path.join(batch_path, index.split(".")[0]+".npy")
        if os.path.isfile(npy_file):
            imgs = np.load(npy_file, mmap_mode='r')
            data = [img for img, label in zip(imgs, targets['labels']) if label is True]
            assert len(data) == 1
            tmp_targets = shuffle(list(zip(imgs, targets['labels'])))
            data += [img for img, label in tmp_targets if label is False][:sample_size]
            return data
    data = [p for p, label in zip(targets['ps'], targets['labels']) if label is True]
    assert len(data) == 1
    tmp_targets = shuffle(list(zip(targets['ps'], targets['labels'])))
    data += [p for p, label in tmp_targets if label is False][:sample_size]
    return data


def get_indices(batch_path="batch"):
    return [f for f in os.listdir(batch_path) if f.endswith("pkl")]


def generate_data(indices, data_path="data", batch_path="batch", batch_size=5, npy_exists=False):
    while(True):
        X = []
        labels = []
        for index in shuffle(indices)[:batch_size]:
            with open(os.path.join(batch_path, index), "rb") as f:
                targets = pickle.load(f)
            if npy_exists:
                data = sampling(targets, index=index, batch_path=batch_path)
            else:
                data = sampling(targets, index=False, batch_path=batch_path)
            label = [True] + [False for _ in data[1:]]
            labels += label
            assert len(data) == len(label)
            if npy_exists:
                X += data
            else:
                X += transform_img(index, data, data_path=data_path, batch_path=batch_path)
        yield np.array(X), np.array(labels)

CNNモデルの概要

Untitled drawing (2).jpg

入力: WebページのスクリーンショットのRGBに対し、「候補dom要素の位置を255、それ以外を0としたalpha値」を加え、全体を255で割ってリサイズしたもの。

出力: 各dom要素に対し、「その要素がコンテンツを含む確率」を出力。

モデル: Sequential CNN

jupyter notebookで実行

最初に、検証データを切り出します。

import dataprocessor as dp
indices = dp.get_indices()

for data in dp.generate_data(indices[0:50], batch_size=50, npy_exists=False):
    eval_data = data
    break

for data in dp.generate_data(indices[50:150], batch_size=100, npy_exists=False):
    eval_data2 = data
    break

fit_generatorに渡すジェネレータをdataprocessorから定義します。

from functools import partial
generator_batch = partial(dp.generate_data, indices=indices[150:], batch_size=1)

モデルを定義します。

from keras.models import Sequential
from keras.layers import Activation, Dropout, Flatten, Dense, SeparableConv2D, MaxPooling2D
from keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint
from sklearn.metrics import accuracy_score, f1_score


model = Sequential()
model.add(SeparableConv2D(32, (3, 3), input_shape=(905, 905, 4)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(SeparableConv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(SeparableConv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

mcp_save = ModelCheckpoint('.mdl_wts.h5', save_best_only=True, monitor='val_loss', mode='min')

訓練します。

model.fit_generator(
    generator_batch(), 
    steps_per_epoch=100,
    epochs=50,
    validation_data=eval_data,
    callbacks=[mcp_save]
)

テストデータで精度を見ます。(テストデータ件数は500件)

preds = []
for i in range(50):
    preds += model.predict(eval_data2[0][i*10:i*10+10]).tolist()

results = []
prev = 0
for i, (p, e) in enumerate(zip(preds, eval_data2[1])):
    if (e == True and i!=0) or i == 499:
        if i==499:
            i=500
        results.append([x[0] for x in preds[prev:i]])
        prev = i

tmp_pred = [np.argmax(result) for result in results]

pred_labels1 = []
for x, r in zip(tmp_pred, results):
    for i in range(len(r)):
        if i == x:
            pred_labels1.append(True)
        else:
            pred_labels1.append(False)

assert len(pred_labels1) == len(eval_data2[1])
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score
print(roc_auc_score(eval_data2[1], pred_labels1))
print(classification_report(eval_data2[1], pred_labels1))

精度の出力:

0.8874999999999998

             precision    recall  f1-score   support

      False       0.95      0.95      0.95       400
       True       0.82      0.82      0.82       100

avg / total       0.93      0.93      0.93       500

補足

予測の際には、「候補要素の予測値の中で最大のものをTrue, ほかをFalseにする」という処理を行っています。

リンク

sugiyamath/information_extraction_exxperiments - https://github.com/sugiyamath/information_extraction_experiments/tree/master/model3

sugiyamath
名もなき小企業の名もなきプログラマー。
http://datanerd.hateblo.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした