背景
ディープラーニングではVGG19という画像分類によく使われる有名なアルゴリズムがあります。
このアルゴリズムは画像の特徴量抽出部分が他のアルゴリズムに比べて簡単な構成ではありますが、精度が良いため、画像の合成にも使われています。
Gatys,et al. Image Style Transfer Using Convolutional Neural Networks
https://pdfs.semanticscholar.org/7568/d13a82f7afa4be79f09c295940e48ec6db89.pdf
この論文を読んでいくなかで、VGG19を使ってできることは何だろうと考えていたところ、「ゆるキャラグランプリなんてものがあったな」と思い出したので、そのデータを収集してみることにしました。
この記事でやること
pythonを使ってデータをダウンロードしてきて、VGG19のアルゴリズムに乗せやすい形に変換します。
この記事でやらないこと
- VGG19とImage Style Transferの理論的な説明
- ダウンロードしたデータを使って実際に投票数を予想するアルゴリズム
実行環境
Linux Mint 17.1 Rebecca
python 3.5.4 (conda 4.3.30)
BeautifulSoup 4.6.0
scikit-image 0.13.1
処理フロー
- 投票結果のページから画像と順位、投票数が記載されているリンクのURLを取得
- 1で集めたリンク先からデータをダウンロード
- 上記1と2を1ページずつ行い、投票結果がそれ以上表示されなくなるまで繰り返す
- 集めた画像と投票数を紐付けてJSON形式に変換する。
URL収集するクラスを作成
次のURLから収集するデータのリンク先を集める部分を作ります。
http://www.yurugp.jp/vote/result_ranking.php?page=1&sort=1
集めるURLは、1例ですが、このような形です。
http://www.yurugp.jp/vote/detail.php?id=00000031
上のページでも順位、投票数、サムネイル画像等、必要なデータを集めることは可能です。
今回はこのデータセットを後々の事を考えて残しておきたいと考えていますので、大きい画像を取得するために下のURLからデータを取ってくることにします。
下が実際のコードです。
collectメソッドで指定されたページのURLを集めます。
クラス定数のURLはpageをfor文で変えられるようフォーマット文字列で定義しておきます。
ここで、sortはゆるキャラグランプリのカテゴリを示していて、1はご当地キャラクター、2は企業・その他のランキングです。これはSourceDataUrlCollectorのインスタンス作成時に指定させるようにしています。
最後に、収集するURLがこれ無い場合はエラーが起こるようになっています。
class SourceDataUrlCollector(object):
URL = 'http://www.yurugp.jp/vote/result_ranking.php?page={page}&sort={character_type_id}'
IMAGE_URL_BASE = 'http://www.yurugp.jp/vote/{}'
IMAGE_URL_PATTERN = re.compile(r'detail\.php\?id')
IMAGE_URL_ID_PATTERN = re.compile(r'detail\.php\?id\=(\d{8})')
CSS_TO_IMAGE_URL = '#charaList > ul.thumbnailList > li > a'
SORT_ID_GOTOCHI = 1
SORT_ID_COMPANY = 2
def __init__(self, character_type):
self.soup = None
check_character_type(character_type)
self.character_type = character_type
def get_character_type_id(self):
check_character_type(self.character_type)
if self.character_type == 'gotochi':
return self.SORT_ID_GOTOCHI
else:
return self.SORT_ID_COMPANY
def get_source_url(self, page):
self.src_url = self.URL.format(page=page, character_type_id=self.get_character_type_id())
return self.src_url
def collect(self, page):
target_urls = []
src_url = self.get_source_url(page)
with urllib.request.urlopen(src_url) as response:
html = response.read()
self.soup = BeautifulSoup(html, HTML_PARSER)
detail_url_src = self.soup.select(self.CSS_TO_IMAGE_URL)
for src in detail_url_src:
url = src.get('href')
if self.IMAGE_URL_PATTERN.match(url):
target_url = self.IMAGE_URL_BASE.format(url)
target_urls.append(target_url)
if not target_url:
raise ValueError('No more data.')
return target_urls
データを実際にダウンロードするクラスを作成
データをダウンロードするクラスを作ります。
SourceDataUrlCollectorで収集したURLをDataDownloaderクラスのdownloadメソッドに渡してダウンロードを実行します。
画像のダウンロードと一緒に次の8項目からなるCSVファイルも保存されます。
entry_no: エントリーNo.
character_type: ランキングのカテゴリ(1: ご当地、2: 企業・その他)
character_name: キャラクターの名前
prefecture: 出身都道府県
ranking: 投票数の順位
point: 投票数
url: 画像のURL
filename: 画像をローカルに保存する際のファイル名
downloadメソッドは、サーバーの負荷を考慮して、受け取ったURLを1つ処理する度にSLEEP_TIME_SECで指定された秒数だけ待ちます。
私は5秒で設定しましたが、ご当地と企業・その他すべてのデータを取得するのに1時間以上掛かりました。
最後にcloseメソッドを呼び出して書きだしたデータをflushさせます。
これを忘れるとCSVデータが意図通りに書き出されない可能性があるので注意してください。
class DataDownloader(object):
CSS_RANK_POINT = '#detail > div.ttl_entry > ul > li.rank_pt'
CSS_CHARACTER_NAME = '#detail > div.ttl_entry > h3'
CSS_ENTRY_NO_PREFECTURE = '#detail > div.ttl_entry > ul > li.entry_no'
CSS_MAIN_IMG = '#detail > div.mainImg > img'
RANK_PATTERN = re.compile('^(\d+)位 / \d+pt$')
POINT_PATTERN = re.compile('^\d+位 / (\d+)pt$')
ENTRY_NO_PATTERN = re.compile(r'^エントリーNo.(\d+)(.+)$')
PREFECTURE_PATTERN = re.compile(r'^エントリーNo.\d+((.+))$')
def __init__(self, work_dir, csv_file, character_type):
self.work_dir = work_dir
check_character_type(character_type)
self.character_type = character_type
self.image_dir = os.path.join(work_dir, character_type, 'image')
self.csv_file = csv_file
self.urls = None
self.fout = open(self.csv_file, 'w')
self.writer = csv.writer(self.fout, lineterminator='\n')
header = ['entry_no', 'character_type', 'character_name', 'prefecture', 'ranking',
'point', 'url', 'filename']
self.writer.writerow(header)
def extract_rank(self):
rank_point_src = self.soup.select(self.CSS_RANK_POINT)[0].text
ranking = self.RANK_PATTERN.findall(rank_point_src)[0]
return ranking
def extract_point(self):
rank_point_src = self.soup.select(self.CSS_RANK_POINT)[0].text
point = self.POINT_PATTERN.findall(rank_point_src)[0]
return point
def extract_character_name(self):
character_name = self.soup.select(self.CSS_CHARACTER_NAME)[0].text
return character_name
def extract_entry_no(self):
entry_no_prefecture_src = self.soup.select(self.CSS_ENTRY_NO_PREFECTURE)[0].text
entry_no = self.ENTRY_NO_PATTERN.findall(entry_no_prefecture_src)[0]
return entry_no
def extract_prefecture(self):
entry_no_prefecture_src = self.soup.select(self.CSS_ENTRY_NO_PREFECTURE)[0].text
prefecture = self.PREFECTURE_PATTERN.findall(entry_no_prefecture_src)[0]
return prefecture
def download_image(self, output_dir):
url_main_img = self.soup.select(self.CSS_MAIN_IMG)[0].get('src')
filename = os.path.basename(url_main_img)
urllib.request.urlretrieve(url_main_img, os.path.join(output_dir, filename))
return (url_main_img, filename)
def download(self, urls):
if not isinstance(urls, list):
raise ValueError('urls must be list.')
self.urls = urls
for url in self.urls:
with urllib.request.urlopen(url) as response:
html = response.read()
self.soup = BeautifulSoup(html, HTML_PARSER)
entry_no = self.extract_entry_no()
character_type = self.character_type
character_name = self.extract_character_name()
prefecture = self.extract_prefecture()
ranking = self.extract_rank()
point = self.extract_point()
url, filename = self.download_image(output_dir=self.image_dir)
data = [entry_no, character_type, character_name, prefecture,
ranking, point, url, filename]
self.writer.writerow(data)
time.sleep(SLEEP_TIME_SEC)
def close(self):
self.fout.flush()
self.fout.close()
実際にデータをダウンロードする
上の2つのクラスを組み合わせたスクリプトは下のとおりです。
標準モジュールのargparseでコマンドラインからワーキングディレクトリとランキングタイプ(ご当地、または、企業・その他)を与えられるようにしてます。
# -*- coding: utf-8 -*-
import argparse
import csv
import os
import logging
import re
import sys
import time
import urllib
from bs4 import BeautifulSoup
MAX_PAGE = 40
SLEEP_TIME_SEC = 5
HTML_PARSER = 'lxml'
LOG_DIR = '/home/ishiyama/yuruchara'
logging.basicConfig(
format='%(asctime)s %(message)s',
filename=os.path.join(LOG_DIR, 'download_images.log'),
level=logging.DEBUG
)
def check_character_type(character_type):
if character_type not in ['gotochi', 'company']:
raise ValueError('character_type must be "gotochi" or "company"')
class SourceDataUrlCollector(object):
URL = 'http://www.yurugp.jp/vote/result_ranking.php?page={page}&sort={character_type_id}'
IMAGE_URL_BASE = 'http://www.yurugp.jp/vote/{}'
IMAGE_URL_PATTERN = re.compile(r'detail\.php\?id')
IMAGE_URL_ID_PATTERN = re.compile(r'detail\.php\?id\=(\d{8})')
CSS_TO_IMAGE_URL = '#charaList > ul.thumbnailList > li > a'
SORT_ID_GOTOCHI = 1
SORT_ID_COMPANY = 2
def __init__(self, character_type):
self.has_data = False
self.soup = None
check_character_type(character_type)
self.character_type = character_type
def get_character_type_id(self):
check_character_type(self.character_type)
if self.character_type == 'gotochi':
return self.SORT_ID_GOTOCHI
else:
return self.SORT_ID_COMPANY
def get_source_url(self, page):
self.src_url = self.URL.format(page=page, character_type_id=self.get_character_type_id())
return self.src_url
def collect(self, page):
target_urls = []
src_url = self.get_source_url(page)
logging.info('SourceDataUrlCollector.download, {}'.format(src_url))
with urllib.request.urlopen(src_url) as response:
html = response.read()
self.soup = BeautifulSoup(html, HTML_PARSER)
detail_url_src = self.soup.select(self.CSS_TO_IMAGE_URL)
for src in detail_url_src:
url = src.get('href')
logging.info('SourceDataUrlCollector.download: Current URL {}'.format(url))
if self.IMAGE_URL_PATTERN.match(url):
logging.info('SourceDataUrlCollector.download: Include: {}'.format(url))
target_url = self.IMAGE_URL_BASE.format(url)
target_urls.append(target_url)
logging.info('SourceDataUrlCollector.download: Success: {}'.format(target_url))
else:
logging.info('SourceDataUrlCollector.download: Exclude: {}'.format(url))
if not target_url:
logging.info('No more data.')
raise ValueError('No more data.')
return target_urls
class DataDownloader(object):
CSS_RANK_POINT = '#detail > div.ttl_entry > ul > li.rank_pt'
CSS_CHARACTER_NAME = '#detail > div.ttl_entry > h3'
CSS_ENTRY_NO_PREFECTURE = '#detail > div.ttl_entry > ul > li.entry_no'
CSS_MAIN_IMG = '#detail > div.mainImg > img'
RANK_PATTERN = re.compile('^(\d+)位 / \d+pt$')
POINT_PATTERN = re.compile('^\d+位 / (\d+)pt$')
ENTRY_NO_PATTERN = re.compile(r'^エントリーNo.(\d+)(.+)$')
PREFECTURE_PATTERN = re.compile(r'^エントリーNo.\d+((.+))$')
def __init__(self, work_dir, csv_file, character_type):
self.work_dir = work_dir
check_character_type(character_type)
self.character_type = character_type
self.image_dir = os.path.join(work_dir, character_type, 'image')
self.csv_file = csv_file
self.urls = None
self.fout = open(self.csv_file, 'w')
self.writer = csv.writer(self.fout, lineterminator='\n')
header = ['entry_no', 'character_type', 'character_name', 'prefecture', 'ranking',
'point', 'url', 'filename']
self.writer.writerow(header)
def extract_rank(self):
rank_point_src = self.soup.select(self.CSS_RANK_POINT)[0].text
ranking = self.RANK_PATTERN.findall(rank_point_src)[0]
return ranking
def extract_point(self):
rank_point_src = self.soup.select(self.CSS_RANK_POINT)[0].text
point = self.POINT_PATTERN.findall(rank_point_src)[0]
return point
def extract_character_name(self):
character_name = self.soup.select(self.CSS_CHARACTER_NAME)[0].text
return character_name
def extract_entry_no(self):
entry_no_prefecture_src = self.soup.select(self.CSS_ENTRY_NO_PREFECTURE)[0].text
entry_no = self.ENTRY_NO_PATTERN.findall(entry_no_prefecture_src)[0]
return entry_no
def extract_prefecture(self):
entry_no_prefecture_src = self.soup.select(self.CSS_ENTRY_NO_PREFECTURE)[0].text
prefecture = self.PREFECTURE_PATTERN.findall(entry_no_prefecture_src)[0]
return prefecture
def download_image(self, output_dir):
url_main_img = self.soup.select(self.CSS_MAIN_IMG)[0].get('src')
filename = os.path.basename(url_main_img)
urllib.request.urlretrieve(url_main_img, os.path.join(output_dir, filename))
return (url_main_img, filename)
def download(self, urls):
if not isinstance(urls, list):
raise ValueError('urls must be list.')
self.urls = urls
for url in self.urls:
logging.info('DataDownloader.download, {}'.format(url))
with urllib.request.urlopen(url) as response:
html = response.read()
self.soup = BeautifulSoup(html, HTML_PARSER)
entry_no = self.extract_entry_no()
character_type = self.character_type
character_name = self.extract_character_name()
prefecture = self.extract_prefecture()
ranking = self.extract_rank()
point = self.extract_point()
url, filename = self.download_image(output_dir=self.image_dir)
data = [entry_no, character_type, character_name, prefecture,
ranking, point, url, filename]
self.writer.writerow(data)
time.sleep(SLEEP_TIME_SEC)
def close(self):
self.fout.flush()
self.fout.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('work_dir', type=str)
parser.add_argument('character_type', type=str, choices=('gotochi', 'company'))
args = parser.parse_args()
work_dir = args.work_dir
character_type = args.character_type
source_url_collector = SourceDataUrlCollector(character_type=character_type)
downloader = DataDownloader(
work_dir=work_dir,
csv_file=os.path.join(work_dir, character_type, 'yuruchara.csv'),
character_type=character_type
)
for page in range(1, MAX_PAGE + 1):
try:
src_url = source_url_collector.collect(page)
except:
logging.info('Exit.')
downloader.close()
sys.exit(1)
time.sleep(SLEEP_TIME_SEC)
downloader.download(urls=src_url)
downloader.close()
保存先のディレクトリ構成です。
/home/ishiyama/yuruchara/
├── company
│ └── image
└── gotochi
└── image
このスクリプトを次のシェルスクリプトにまとめて実行しました。
#!/bin/sh
work_dir="お好みのディレクトリ"
# ここではdownload_images.pyとワーキングディレクトリを同じにしていますが、
# この通りにする必要はありません。
python ${work_dir}/download_images.py ${work_dir} company
python ${work_dir}/download_images.py ${work_dir} gotochi
# (オプション)データセットを保管しておきたいのでtar.gzで固める
tar -zcvf yuruchara.tar.gz ${work_dir}
実行後はこのようにファイルが出力されています。
/home/ishiyama/yuruchara_test/
├── company
│ ├── image
│ │ ├── 00000041.jpg
│ │ ├── ...
│ │ └── 00003712.jpg
│ └── yuruchara.csv
├── download_images.log
├── download_images.py
├── downloader.sh
└── gotochi
├── image
│ ├── 00000031.jpg
│ ├── ...
│ └── 00002537.jpg
└── yuruchara.csv
VGG19向けにダウンロードした画像を加工する
元の論文を見る限りはどのサイズでも大丈夫なように見えますが、GitHubで公開されている他の人の実装を見てみますと、誰もそこにはチャレンジしていないので、安全策で画像を225ピクセル×225ピクセルに変換します。
また、保存形式はGoogle Cloud MLで使いやすいようにJSONで無圧縮にします。
ご当地と企業・その他の変換後のデータは合計で1.3GBを超えますので本来だったら圧縮したいところですが、GCPのセキュリティ上の制約でpythonからzipを解凍する関数が使えません。
この面倒を回避したいので、生のJSONにしておくことにします。
(もし他に良い方法をお持ちの方がいらっしゃいましたら、是非ご教示いただけますでしょうか)
変換の手順ですが、今回取得した画像のサイズは横650×縦790なので
- 縦が650になるように上から650ピクセルを切り出す
- 225×225になるようにリサイズ
で進めていきます。
スクリプトは次の通りです。
このスクリプトではGIF画像を省いています。これはGIFをJPEGと同じように扱うことができなかったためです。
やり方はあるようですが、今回はそこまで深追いせずにJPEGのみを対象にすることにしました。
# -*- coding: utf-8 -*-
import csv
import json
import os
import numpy as np
import skimage
from skimage.io import imread
from skimage.transform import resize
WORK_DIR = os.getcwd()
CROP_POSITION_LEFT = 0
CROP_POSITION_UPPER = 0
CROP_POSITION_RIGHT = 650
CROP_POSITION_BOTTOM = 650
RESIZE_WIDTH = 225
RESIZE_HEIGHT = 225
ENCODING = 'UTF-8'
def get_src_image_path(character_type, filename):
return os.path.join(WORK_DIR, character_type, 'src', 'image', filename)
def get_src_data_path(character_type):
return os.path.join(WORK_DIR, character_type, 'yuruchara.csv')
def get_output_data_path(character_type):
return os.path.join(WORK_DIR, character_type, '{}.json'.format(character_type))
def check_gif(filename):
extension = os.path.splitext(filename)[1]
return (extension == '.gif')
def make_json(character_type):
train_image_list = []
point_list = []
with open(get_src_data_path(character_type), 'r', encoding=ENCODING) as f:
reader = csv.DictReader(f)
for record in reader:
# gif画像はresizeできないのでスキップする
if check_gif(record['filename']):
continue
src_image = imread(get_src_image_path(character_type, record['filename']))
cropped_image = src_image[CROP_POSITION_LEFT:CROP_POSITION_RIGHT, CROP_POSITION_UPPER:CROP_POSITION_BOTTOM, :]
training_image = resize(image=cropped_image, output_shape=(RESIZE_HEIGHT, RESIZE_WIDTH))
train_image_list.append(training_image.tolist())
point_list.append(record['point'])
return {'supervise': point_list, 'train': train_image_list}
if __name__ == '__main__':
for character_type in ['company', 'gotochi']:
with open(get_output_data_path(character_type), 'w', encoding=ENCODING) as f:
json.dump(make_json(character_type), f)
こちらはターミナルから下記のように実行するだけです。
(download_images.pyとの統一感に欠けていて申し訳ありません...)
python image_converter.py
以上でVGG19に使えるデータができました。
まとめ
ゆるキャラグランプリのご当地と企業・その他のランキングから画像や投票数、順位等のデータを取得するpythonスクリプトについて解説しました。
今後は実際にVGGを使って投票数を予想するロジックと実装について書きたいと思います。
参考文献
Gatys,et al. Image Style Transfer Using Convolutional Neural Networks
https://pdfs.semanticscholar.org/7568/d13a82f7afa4be79f09c295940e48ec6db89.pdf