Help us understand the problem. What is going on with this article?

予測モデルの適用範囲・適用領域を可視化してみる

はじめに

作成した予測モデルにより、あるデータの予測を行う場合、その予測モデルがどのようなデータに対して利用できるかの適用範囲・適用領域 (Applicability Domain, AD) の考え方が重要となる。適用範囲・適用領域を超えて予測を行った場合、その信頼性が問われることになる。

適用範囲・適用領域については、

  • 学習データの中心からの距離
  • データ密度

等、参考URLにも記載されているように、さまざまな方法が存在する。
今回は、その中でも「データ密度」をmatplotlibにより可視化してみた。

データ密度の定義

予測モデル、あるデータがあったときに、そのデータ点におけるデータ密度を以下の通りの定義とした。

そのデータと最も距離が近いN個の学習データに対する距離の総和

いわいるK最近傍法である。距離としては、今回ユークリッド距離を利用した。

可視化方法

  • 学習データはscikit-learnのmake_regressionを用いて生成。データ密度の表現を確認するため、学習データを2つのクラスタに分離するようにした。
  • 説明変数の数は2個とした。PCA、UMAPなどにより3個以上の説明変数を2次元に圧縮して表示することも考えたが、等高線図の作成が難しいため見送りとした。
  • 各格子点におけるデータ密度を算出。(K最近傍法数K=1, 3, 15, 学習データ全部のパターンで算出)
  • 学習データをスキャッタープロットでプロット。データ密度は等高線図により可視化。

 環境

  • python3.6
  • matplotlib 3.1.1
  • scikit-learn 0.21.2

ソース

import numpy as np

from sklearn.datasets import make_regression, make_classification
import argparse
import pandas as pd
import numpy as np
import umap
import matplotlib
import matplotlib.pyplot as plt


def main():

    # --------------------------------------
    # データの生成
    # --------------------------------------
    dataX1, y1 = make_regression(n_samples=20, n_features=2, random_state=0)
    dataX2, y2 = make_regression(n_samples=30, n_features=2, random_state=0)

    dataX1 = dataX1 - 2
    dataX2 = dataX2 + 2
    datas = np.vstack((dataX1, dataX2))

    x = np.arange(-8, 8, 0.4)
    y = np.arange(-8, 8, 0.4)

    X, Y = np.meshgrid(x, y)
    Z = np.zeros([X.shape[0], X.shape[1]])

    # ------------------------------
    # データ密度の計算
    # ------------------------------
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            query_x = X[i, j]
            query_y = Y[i, j]

            #count = len(datas)      #考慮する近傍の数
            count = 15  # 考慮する近傍の数
            #count = 3  # 考慮する近傍の数
            #count = 1  # 考慮する近傍の数
            distances = [] #近傍を格納する配列

            for target_x, target_y in datas:
                distance = np.sqrt(np.square(query_x-target_x) + np.square(query_y-target_y))

                if len(distances) < count:
                    distances.append(distance)
                    print("NEW IN {0} len={1}".format(distance, len(distances)))
                else:
                    distances.sort(reverse=True)
                    for k, d in enumerate(distances):
                        if d > distance:
                            distances[k] = distance
                            print("IN {0} len={1}".format(distance, len(distances)))
                            print("OUT {0} len={1}".format(distance, len(distances)))

            sum = 0.0
            for d in distances:
                sum += d

            print("SUM={0}".format(sum))
            Z[i, j] = sum

    fig = plt.figure(figsize=(12, 5))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.pcolormesh(X, Y, Z, cmap='hsv')
    ax1.scatter(datas[:, 0], datas[:, 1])

    plt.show()

if __name__ == "__main__":
    main()

可視化

全データ(50個)を利用した場合

image.png

K=15の場合

image.png

K=3の場合

image.png

K=1の場合

image.png

考察

  • 全データを利用した場合、データの中心が最も密度が高くなるという、予想された結果となった。この例でいうと2つのクラスタの間の密度が少ないところも密度が高いと評価されてしまうことになるため、データ密度の評価としては問題がある。
  • K=15, K=3ともにクラスタの真ん中部分の密度が低くなっており、いい感じの表示になっている。
  • K=1の場合は、左下の孤立した1点の周りのデータ密度も高くなっており、やや極端な気がする。

参考

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away