6
4

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 3 years have passed since the last update.

アニメ・イラストの顔を集めたデータセットを作りました

Posted at

データセットの内容

ちょっと前までGANにハマっていて、その時に作った遺産です。
kaggleのページに公開しました。
もし著作権的にアウトであれば、削除します。

内容はSafebooruからスクレイピングし、顔部分をトリミングした約500,000枚の画像です。
画像のサイズは $256 \times 256$ 以上で、揃えていません(できるだけ高解像度にしてあります)。

他のデータセットと違う特徴として、目の位置を検出することで、顔の位置や角度を揃えてあります。
このことでGANなどで学習を行うとき、安定しやすくなっているのではないかと思います。
また、誤検出などで生じた、顔以外の画像、位置のずれた画像などはSVMで学習させて弾いています。

顔の検出には有名なlbpcascade_animeface.xmlを使用しました。
目の検出には自作の検出器を使っています(昔公開したOpenCVの物体検出の訓練データを作るためのツールを使いました)。
Safebooruからのスクレイピングには参考にしたページがあった気がするのですが、忘れてしまいました。

コード

スクレイピング

scraping.py
import urllib.request
import xml.etree.ElementTree
import cv2
import os
import numpy
import math
import PIL.Image
import io
import concurrent.futures
import time

# Smallest accepted side length in pixels: used as the post size filter,
# the face detector's minSize and the lower bound on the aligned crop size.
minimum_size: int = 256
# Distance between the two eyes in an output crop, as a fraction of the
# crop's side length.
target_eye_distance: float = 0.15

# Folder that receives the aligned face crops.
output_directory: str = 'D:/safebooru_face/'


def reject_image(image):
    """Return True when *image* is (nearly) monotone and should be skipped.

    The image is downscaled to 256x256 and the 3x3 covariance of its colour
    channels is computed. If the largest eigenvalue is at least 1000 times
    the second largest, the pixel colours lie almost on a single line in
    colour space (e.g. grayscale or a single tint), so the image is rejected.

    Args:
        image: 3-channel colour image (``reshape(-1, 3)`` requires exactly
            three channels).

    Returns:
        bool: True if the image should be rejected.
    """
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
    pixels = image.reshape(-1, 3)
    cov_image = numpy.cov(pixels, rowvar=False)
    # A covariance matrix is symmetric: eigvalsh returns real eigenvalues in
    # ascending order, whereas eigvals may yield tiny spurious imaginary parts.
    eigen_values = numpy.linalg.eigvalsh(cov_image)

    # ">=" so that a perfectly flat image (all eigenvalues zero) also counts
    # as monotone; with ">" it would pass because 0 > 0 is False.
    is_monotone = eigen_values[2] >= eigen_values[1] * 1000
    return bool(is_monotone)


def trim_image(image, face_classifier, eye_classifier):
    """Detect faces in *image* and return eye-aligned square crops.

    For each face found by *face_classifier*, the eye detector is run on the
    face region with a decreasing ``minNeighbors`` until exactly two eyes are
    found. The image is then rotated about the eye midpoint so that the eyes
    are level. A square centred on that midpoint is cut out, sized so that
    the eye distance equals ``target_eye_distance`` times the side length.

    Faces are skipped when:
        * two eyes cannot be found,
        * the eyes are closer than 5 px, or
        * the crop would be smaller than ``minimum_size``.

    Args:
        image: BGR image.
        face_classifier: cascade classifier used to find faces.
        eye_classifier: cascade classifier used to find eyes inside a face.

    Returns:
        list: aligned crops (possibly empty). Areas outside the source image
        are filled with white.
    """
    result_images = []

    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray_image = cv2.equalizeHist(gray_image)
    image_height, image_width = gray_image.shape[:2]

    face_rects = face_classifier.detectMultiScale(gray_image,
                                                  scaleFactor=1.02,
                                                  minNeighbors=2,
                                                  minSize=(minimum_size, minimum_size))

    for (face_x, face_y, face_w, face_h) in face_rects:
        center_x = face_x + face_w / 2
        center_y = face_y + face_h / 2
        trim_size = max(face_w, face_h)

        # Square region around the face, clamped to the image. A negative
        # start index would make numpy slicing wrap around and produce an
        # empty or wrong ROI.
        x_left = max(int(center_x - trim_size / 2), 0)
        x_right = min(int(center_x + trim_size / 2), image_width)
        y_top = max(int(center_y - trim_size / 2), 0)
        y_bottom = min(int(center_y + trim_size / 2), image_height)

        gray_face_roi = gray_image[y_top:y_bottom, x_left:x_right]
        if gray_face_roi.size == 0:
            continue

        # Loosen the eye detector step by step until exactly two eyes appear.
        for min_neighbors in [10, 8, 6, 5, 4, 3, 2, 1, 0]:
            eye_rects = eye_classifier.detectMultiScale(gray_face_roi,
                                                        scaleFactor=1.02,
                                                        minNeighbors=min_neighbors,
                                                        minSize=(trim_size // 32, trim_size // 32))

            if len(eye_rects) == 2:
                break
        else:
            continue

        eye_positions = []
        for (eye_rect_x, eye_rect_y, eye_rect_w, eye_rect_h) in eye_rects:
            # Eye centre converted from ROI to full-image coordinates.
            eye_x = x_left + eye_rect_x + eye_rect_w / 2
            eye_y = y_top + eye_rect_y + eye_rect_h / 2

            if x_left < eye_x < x_right and y_top < eye_y < y_bottom:
                eye_positions.append(numpy.array([eye_x, eye_y]))

        # The bounds filter above may drop an eye. Indexing [1] would then
        # raise IndexError.
        if len(eye_positions) != 2:
            continue

        if eye_positions[0][0] < eye_positions[1][0]:
            left_eye, right_eye = eye_positions
        else:
            right_eye, left_eye = eye_positions

        eye_center = (left_eye + right_eye) / 2
        # Angle of the line through the eyes; rotating by it levels them.
        face_angle = math.degrees(math.atan2(right_eye[1] - left_eye[1], right_eye[0] - left_eye[0]))
        eye_distance = math.hypot(right_eye[0] - left_eye[0], right_eye[1] - left_eye[1])

        if eye_distance < 5:
            continue

        # Output side chosen so that eye_distance == target_eye_distance * side.
        trim_size = eye_distance / target_eye_distance

        if trim_size < minimum_size:
            continue

        # Rotate about the eye midpoint, then translate so that the midpoint
        # lands in the centre of the output square.
        rotation_matrix = cv2.getRotationMatrix2D((eye_center[0], eye_center[1]), face_angle, 1.0)
        rotation_matrix[0, 2] += trim_size / 2 - eye_center[0]
        rotation_matrix[1, 2] += trim_size / 2 - eye_center[1]

        output_size = int(trim_size)
        trimmed_image = cv2.warpAffine(image, rotation_matrix, (output_size, output_size),
                                       flags=cv2.INTER_LANCZOS4, borderValue=(255, 255, 255))

        result_images.append(trimmed_image)

    return result_images


def download_image(file_url, timeout=60.0):
    """Download *file_url* and decode it into a 3-channel BGR uint8 image.

    Args:
        file_url: URL of the image to fetch.
        timeout: socket timeout in seconds, so that a stalled connection
            cannot block a worker thread forever.

    Returns:
        ``(image, file_url)`` on success, where ``image`` has shape
        ``(height, width, 3)``. Returns ``None`` if downloading or decoding
        fails; the error is printed.
    """
    try:
        with urllib.request.urlopen(file_url, timeout=timeout) as response:
            bin_image = io.BytesIO(response.read())

        with PIL.Image.open(bin_image) as pil_image:
            # convert('RGB') normalises every PIL mode (L, LA, P, CMYK, RGBA,
            # ...) to 3-channel RGB. Dispatching on the array shape instead
            # fails for 2-D grayscale/palette arrays and treats 4-channel
            # CMYK data as RGBA.
            rgb_image = numpy.array(pil_image.convert('RGB'), dtype=numpy.uint8)

        image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
        return image, file_url
    except Exception as e:
        print(e)
        return None


if __name__ == '__main__':
    os.makedirs(output_directory, exist_ok=True)

    # Exclusion tags passed to the Safebooru API query.
    tags = [
        '-grayscale', '-greyscale', '-no_humans', '-glasses', '-furry'
    ]

    face_classifier = cv2.CascadeClassifier('lbpcascade_animeface.xml')
    eye_classifier = cv2.CascadeClassifier('lbpcascade_animeeye4.xml')

    for page_number in range(100000):
        print('Page : {}'.format(page_number))

        url = 'https://safebooru.org/index.php?page=dapi&s=post&q=index&pid={pid}&tags={tags}'.format(
            pid=page_number,
            tags='+'.join(tags)
        )

        # Retry the listing request until it succeeds, pausing 1 s between
        # attempts. The timeout keeps a stalled connection from hanging.
        while True:
            try:
                with urllib.request.urlopen(url, timeout=60.0) as response:
                    content = response.read()
                break
            except Exception as e:
                print('Failed.')
                print(e)
                time.sleep(1.0)

        print('Success!')

        try:
            posts = list(xml.etree.ElementTree.fromstring(content).iter('post'))
        except xml.etree.ElementTree.ParseError as e:
            # The response was not valid XML; skip this page.
            print(e)
            continue

        print('n_posts : {}'.format(len(posts)))

        # An empty listing means we have paged past the last result. Without
        # this check the loop would keep requesting empty pages.
        if not posts:
            break

        file_urls = []

        for post in posts:
            try:
                width = int(post.get('width'))
                height = int(post.get('height'))
            except (TypeError, ValueError):
                # Missing or malformed size attributes: skip the post rather
                # than crash the whole crawl.
                continue

            if width < minimum_size or height < minimum_size:
                continue

            # Skip very large images (more than 3000 * 3000 pixels).
            if width * height > 3000 * 3000:
                continue

            file_url = post.get('file_url')
            if not file_url:
                continue

            extension = os.path.splitext(os.path.basename(file_url))[1].lower()

            if extension == '.gif':
                continue

            file_urls.append(file_url)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(download_image, file_url) for file_url in file_urls]

            for future in futures:
                try:
                    result = future.result()
                    # download_image returns None when it fails.
                    if result is None:
                        continue

                    image, file_url = result
                    file_name = os.path.splitext(os.path.basename(file_url))[0]

                    if reject_image(image):
                        continue

                    trimmed_images = trim_image(image, face_classifier, eye_classifier)

                    for image_counter, trimmed_image in enumerate(trimmed_images):
                        output_path = os.path.join(output_directory, file_name + '_' + str(image_counter) + '.png')
                        if cv2.imwrite(output_path, trimmed_image):
                            print('{} -> {}'.format(file_url, output_path))
                        else:
                            print('Failed to write {}'.format(output_path))

                except Exception as e:
                    print(e)

SVMによる誤検出の検出

svm.py
import cv2
import glob
import numpy
import sklearn.ensemble
import sklearn.svm
import skimage.feature
import sklearn.pipeline
import sklearn.decomposition
import sklearn.preprocessing
import os
import shutil
import pickle
import traceback


def split_list(input_list, n):
    """Cut *input_list* into consecutive slices of length *n*.

    The last slice may be shorter. Each slice has the same type as
    *input_list*, because slicing is used.
    """
    chunks = []
    for start in range(0, len(input_list), n):
        stop = start + n
        chunks.append(input_list[start:stop])
    return chunks


def preprocess(file_name):
    """Load an image file and turn it into a feature vector for the classifier.

    Steps:
        1. Cut out a central square whose side is roughly half of the
           shorter image side.
        2. Resize it to 256x256.
        3. Describe it by its HOG features (32x32-pixel cells),
           concatenated with the raw values of an 8x8 thumbnail of the
           same crop.

    Args:
        file_name: path of the image to read with ``cv2.imread``.

    Returns:
        numpy.ndarray: 1-D feature vector.

    Raises:
        OSError: if ``cv2.imread`` cannot read the file.
    """
    image = cv2.imread(file_name)

    if image is None:
        raise OSError('File is not found.')

    # numpy.float (an alias of the builtin float, i.e. float64) was removed
    # in NumPy 1.24; use the explicit dtype.
    image = image.astype(numpy.float64)

    height, width = image.shape[0], image.shape[1]
    half_crop = min(height, width) // 4

    # Axis 0 is rows (height), axis 1 is columns (width).
    top = height // 2 - half_crop
    bottom = height // 2 + half_crop
    left = width // 2 - half_crop
    right = width // 2 + half_crop
    image = image[top:bottom, left:right, :]

    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)

    # NOTE(review): ``image`` still has 3 colour channels here. Recent
    # scikit-image releases require ``channel_axis=-1`` for multichannel
    # input to hog; confirm against the installed version.
    hog = skimage.feature.hog(image, pixels_per_cell=(32, 32)).ravel()
    small_image = cv2.resize(image, (8, 8), interpolation=cv2.INTER_AREA).ravel()
    return numpy.concatenate([hog, small_image])


def _training_set(dataset, ok_files, ng_files):
    """Collect cached feature vectors and labels for the labelled files.

    Args:
        dataset: dict mapping file basename -> feature vector.
        ok_files: paths used as label 1 (OK) examples.
        ng_files: paths used as label 0 (NG) examples.

    Returns:
        tuple: ``(x, y)`` lists of feature vectors and integer labels. Files
        with no entry in *dataset* (their preprocessing failed) are skipped
        instead of raising KeyError.
    """
    x = []
    y = []
    for files, label in ((ok_files, 1), (ng_files, 0)):
        for file in files:
            features = dataset.get(os.path.basename(file))
            if features is not None:
                x.append(features)
                y.append(label)
    return x, y


def _move_confident_files(classifier, dataset, input_files, ok_destination, ng_destination):
    """Score *input_files* in batches of 1000 and move the confident ones.

    Routing by OK probability:
        * above 0.99: moved to *ok_destination*;
        * below 0.01: moved to *ng_destination*;
        * otherwise: left in place.

    Features come from *dataset* when cached, otherwise from ``preprocess``.
    """
    for batch in split_list(input_files, 1000):
        features = []
        scored_files = []
        for file in batch:
            if not os.path.exists(file):
                continue

            print(f'Processing : {file}')

            key = os.path.basename(file)
            try:
                features.append(dataset[key] if key in dataset else preprocess(file))
            except Exception as e:
                print(e)
                print(traceback.format_exc())
                continue
            scored_files.append(file)

        if not features:
            continue

        # Column 1 is the probability of label 1 (OK): predict_proba columns
        # follow the sorted class labels [0, 1].
        scores = classifier.predict_proba(numpy.stack(features, axis=0))[:, 1]
        for file, score in zip(scored_files, scores):
            if score < 0.01:
                destination = ng_destination
            elif score > 0.99:
                destination = ok_destination
            else:
                continue

            try:
                shutil.move(file, os.path.join(destination, os.path.basename(file)))
            except OSError as e:
                print(e)


def main():
    """Cache features of the labelled crops, train the classifier, and sort new crops.

    Crops under ``input_dir`` that the classifier is confident about are
    moved into the OK / NG candidate folders.
    """
    true_dir = 'D:/safebooru_class/ok_candidate'  # destination for confident OK predictions
    false_dir = 'D:/safebooru_class/ng_candidate'  # destination for confident NG predictions
    input_dir = 'D:/safebooru_face'  # crops to classify
    ok_dir = 'D:/safebooru_class/ok'  # training examples with label 1
    ng_dir = 'D:/safebooru_class/ng'  # training examples with label 0
    dataset_file = 'D:/safebooru_class/dataset'  # pickled {basename: features} cache

    # Without these folders every shutil.move fails with OSError.
    os.makedirs(true_dir, exist_ok=True)
    os.makedirs(false_dir, exist_ok=True)

    ok_files = glob.glob(ok_dir + '/*')
    ng_files = glob.glob(ng_dir + '/*')
    all_files = ok_files + ng_files

    # NOTE: pickle.load can run arbitrary code; only load a cache file that
    # this script wrote itself.
    if os.path.exists(dataset_file):
        with open(dataset_file, 'rb') as f:
            dataset = pickle.load(f)
    else:
        dataset = {}

    dataset_changed = False

    for i, file in enumerate(all_files):
        key = os.path.basename(file)
        if key in dataset:
            continue

        print(f'Process : {file} ({i} / {len(all_files)})')

        try:
            dataset[key] = preprocess(file)
        except Exception as e:
            # Unprocessable files are simply left out of the training set.
            print(e)
            continue
        dataset_changed = True

    if dataset_changed:
        print('Writing Dataset...')
        with open(dataset_file, 'wb') as f:
            pickle.dump(dataset, f)
        print('Done.')

    # The recursive glob also yields directories (input_dir itself included).
    input_files = [file for file in glob.glob(input_dir + '/**', recursive=True)
                   if os.path.isfile(file)]

    if len(input_files) < 100:
        print('Too few input files; nothing to classify.')
        return

    x, y = _training_set(dataset, ok_files, ng_files)
    print('End Preprocess.')

    # PCA(n_components=100) needs at least 100 samples, and the SVC needs
    # examples of both classes.
    if len(y) < 100 or len(set(y)) < 2:
        print('Not enough labelled data (need >= 100 samples covering OK and NG).')
        return

    classifier = sklearn.pipeline.Pipeline(steps=[
        ('scale1', sklearn.preprocessing.RobustScaler()),
        ('pca', sklearn.decomposition.PCA(n_components=100, whiten=True)),
        ('svc', sklearn.svm.SVC(probability=True)),
    ])

    classifier.fit(numpy.stack(x, axis=0), numpy.array(y))

    print('End Train.')

    _move_confident_files(classifier, dataset, input_files, true_dir, false_dir)


if __name__ == '__main__':
    main()

コードが汚い。

6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
6
4

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?