1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

重回帰分析をpythonでやってみる

Last updated at Posted at 2022-03-16

回帰分析ってなに? 回帰分析を使ってみる!

単回帰分析とは、1つの目的変数[Y]を1つの説明変数[X]で予測するもので、その2変量の間の関係性をY=aX+bという一次方程式の形で表します。  最小二乗法を使って、そこそこ簡単な計算式で傾き[a]と切片[b]を求めることができるのですが、それを求めてどう使うのでしょう?

  • 一つは、分析してY=aX+bの aとbが求まるのですから、任意のXに対するYを予測することができます。
  • もう一つは、 XとYの相関を得ることができます。 相関は -1.0〜+1.0までの間で、|1.0|に近いほど相関が強く、|0.2|以下くらいだと相関がないと判断できます。

 例えば、あるブランドの服を購入する人の体重を目的変数Yとして、身長を説明変数[X]とするなら、XからYを予測できるってものです。 もちろん身長と体重だけの分析では不足がありますから、正確な予想にはなりませんが、そのブランドの購入者の一つの傾向として細身傾向、中肉傾向、みたいなことが分かる、ってわけです。 

重回帰分析。こっちが解析としては本命かもしれません。結論だけいうなら、
 Y = a1・X1 + a2・X2 + .... + an・Xn のX1〜Xnの n変量の関係性を方程式で表します。 a1〜an の値が大きければ大きいほど、その項目の比重が大きいってことになります。 便利だと思いません?
 いろんな情報から、何が一番大事か、次に何が大事か、が分析できちゃんうんです。
 これも単回帰分析と同じで最小二乗法で求めるのですが、要素数がいっぱいあってそれらの共分散行列の逆行列演算とかが必要で、いちいち計算してられないので、ツールを使っちゃいます。それが scikit-learnってtoolです。

 以降にpythonでの実装方法や動作方法を示しますが、何より意識して欲しいのは、これはあくまで分析のための計算方法ってことです。 例えば重回帰分析にかける説明変数について。これ、何でも良いわけじゃあありません。相関が強い変数同士を指定すると、正確な分析ができなかったりします。 また、そもそも選択した変数が売り上げ向上に結びつかなければ分析に意味なんてありませんし、a1〜an がどれも似たり寄ったりの値では、優先度をつけることができません。検証にはなりますが。
 大事なことは、必要なデータを収集し、分析し、Actionに繋げる、この3つがセットでなければならないのです。ゴミからはゴミしか生まれない。ゴミ・データを解析をしてもゴミ・結果しか生まれない。いわゆるガーベージイン・ガーベージアウトです。 そうでなくするためにゴミ・データをうまく抽出したり、うまく加工したりして、素晴らしい結果を生む。つまり、ガーベージイン・ゴスペルアウトにしなければなければ意味がないのです。
 「想定する結果から、必要なデータを収集する」そんなアプローチも必要だということを肝に銘じて、こうした分析を活用してください。

単回帰分析サンプルプログラムの実行方法

以下は、説明変数を sepal-length とし、目的変数を petal-length とした単回帰分析をしたものです。
動作させると表が出力され、

  • y = ax + b の a を [coef]、bを[intercept]として、また相関を [R] として表示します。
  • -column の値を変更すれば、それぞれの変数の相関を見ることもできます。
  • -query は、データセットから、花の種類を限定するのに使っています。
python3 regression.py -intype csv -in iris.csv -type simple-regression -column sepal-length -column petal-length -query "setosa_versicolor_virginica == 0"

重回帰分析サンプルプログラムの実行方法

以下は4つの説明変数(sepal-length, sepal-width, petal-length, petal-width)ガクの長さ・幅と花弁の長さ・幅と目的変数(setosa_versicolor_virginica)の関係性を得ています。  setosa_versicolor_virginica は花の種類で setosa[0],versicolor[1],virginica[2]を表しています。

python3 regression.py -intype csv -in iris.csv -type multi-regression -column sepal-length -column sepal-width -column petal-length -column petal-width -objcolumn setosa_versicolor_virginica

計算すると以下が出力され、四つ目の説明変数 petal-width 花弁の幅の比重が大きいことが見て取れます。
言い換えると、花弁の長さ|幅で、 setosa[0],versicolor[1],virginica[2]の花は識別しやすい、ってことがわかるのです。

coef([-0.11311259 -0.02132399  0.49268913  0.56686618]) intercept(2.662115665390031e-16)

回帰分析のサンプル実装

データCSV から回帰分析する python プログラム
注)重回帰分析をする際には、データの単位などを標準化して統一する必要があります。
  そのため multi_regression関数の最初に標準化を加筆しました。

regression.py
# install module
#    python3 -m pip install pandas numpy scikit-learn matplotlib

import sys
import pandas as pd                # dataset (pandasを利用)
from sklearn import linear_model   # scikit-leanを使った回帰分析
import matplotlib.pyplot as plt    # グラフの表示用

#
# 単回帰分析
#
def simple_regression(df, targets):
    x = df[[targets[0]]]
    y = df[[targets[1]]]

    # 単回帰分析
    lr = linear_model.LinearRegression()
    lr.fit(x, y)
    r_score = lr.score(x,y)

    equation = 'coef(' + str(lr.coef_[0]) + ') intercept(' + str(lr.intercept_[0]) + ')  R[' + str(r_score) + ']'
    #print('params : ' + str(lr.get_params()))

    # グラフ表示
    plt.figure()
    plt.scatter(x, y)
    plt.plot(x, lr.predict(x), color='red', linewidth=2)
    plt.title(targets[0] + ' vs ' + targets[1])
    plt.xlabel(targets[0])
    plt.ylabel(targets[1])
    plt.legend(['plot', equation])
    plt.show()

#
# 重回帰分析
#
def multi_regression(df, targets, objtarget):
    # standard
    tlen = len(targets)
    df[objtarget] = (df[objtarget] - df[objtarget].mean()) /  df[objtarget].std(ddof=0)  # for standard
    for i in range(0, tlen):
        df[targets[i]] = (df[targets[i]] - df[targets[i]].mean()) / df[targets[i]].std(ddof=0)
    x = df[targets]
    y = df[[objtarget]]

    lr = linear_model.LinearRegression()
    lr.fit(x, y)

    equation = 'coef(' + str(lr.coef_[0]) + ') intercept(' + str(lr.intercept_[0]) + ')'
    print(equation)
    #print('params : ' + str(lr.get_params()))

def procCsv(args):
    # check parameter
    if (args.get('type') is None):
        usage('** error, lack of arg, -type')
    if (args.get('column') is None):
        usage('** error, lack of arg, -column')

    columns = args['column']
    if (not isinstance(columns, list)):
        usage('** error, illegal arg, -column')
    if (len(columns) < 2):
        print('** error, need 2-columns like (x,y) at least')
        return

    sep = ','
    if (args.get('separator') is not None):
        sep = args[separator]

    # load csv to dataframe in pandas
    try:
        df = pd.read_csv(args['in'], sep=sep)
        if (args.get('query') is not None):
            df = df.query(args['query'])
        #print(df)
    except Exception as e:
        print('** error, load csv exception[' + args['in'] + '], ' + str(e))
        return

    # check columns
    dfcolumns = list(df)
    for column in columns:
        if (column not in dfcolumns):
            print('** error, column is not in csv-datas, column[' + column + '] , csv-columns[' + str(dfcolumns) + ']')
            return

    # regression
    if (args['type'] == 'simple-regression'):
        simple_regression(df, columns)

    elif (args['type'] == 'multi-regression'):
        if (args.get('objcolumn') is None):
            usage('** error, lack of args, -objcolumn')
        if (args['objcolumn'] not in dfcolumns):
            print('** error, object-column is not in csv-datas, column[' + args['objcolumn'] + '] , csv-columns[' + str(dfcolumns) + ']')
            return
        multi_regression(df, columns, args['objcolumn'])

    else:
        usage('** error, illegal arg, -type')


def usage(error):
    if (error is not None):
        print(error)
    print('usage:')
    print(' option : -intype {csv} -in {input-file} -type {simple-regression|multi-regression} [argment]')
    print('          -type simple-regression -column {x-column} -column {y-column} [-query {pandas query}]')
    print('          -type multi-regression -column {x1-column} -column {x2-column} [... -column {xn-column}] -objcolumn {column}  [-query {pandas query}]')
    sys.exit()

#
# main
#
def parseArgs(args):
    res = {}
    type = False
    typestr = ''
    nobase64 = False
    for index, arg in enumerate(args):
        # command name
        if (index == 0):
            continue

        if (arg[0] == '-'):
            typestr = arg[1:]
            if (res.get(typestr) is None):
                res[typestr] = ''
            type = True
            continue

        if (isinstance(res[typestr], list)):
            res[typestr].append(arg)
        else:
            if (len(res[typestr]) == 0):
                res[typestr] = arg
            else:
                ldata = [res[typestr]]
                res[typestr] = ldata
                res[typestr].append(arg)
        type = False

    return res


def main(args):
    params = parseArgs(args)
    if (params.get('intype') is None):
        usage('** error, lack of arg, -intype')
    if (params.get('in') is None):
        usage('** error, lack of arg, -in')

    if (params['intype'] == 'csv'):
        procCsv(params)
    else:
        usage('** error, illegal arg, -intype')


if __name__ == '__main__':
    main(sys.argv)

実験用のデータ scikit-learn の付属データ

scikit-learn サイトの iris.csv データ

カラム内容: がく長, がく幅, 花弁長, 花弁幅, 花の種類(sentosa[0]|versicolor[1]|virginica[2])

iris.csv
sepal-length,sepal-width,petal-length,petal-width,setosa_versicolor_virginica
5.1,3.5,1.4,0.2,0
4.9,3.0,1.4,0.2,0
4.7,3.2,1.3,0.2,0
4.6,3.1,1.5,0.2,0
5.0,3.6,1.4,0.2,0
5.4,3.9,1.7,0.4,0
4.6,3.4,1.4,0.3,0
5.0,3.4,1.5,0.2,0
4.4,2.9,1.4,0.2,0
4.9,3.1,1.5,0.1,0
5.4,3.7,1.5,0.2,0
4.8,3.4,1.6,0.2,0
4.8,3.0,1.4,0.1,0
4.3,3.0,1.1,0.1,0
5.8,4.0,1.2,0.2,0
5.7,4.4,1.5,0.4,0
5.4,3.9,1.3,0.4,0
5.1,3.5,1.4,0.3,0
5.7,3.8,1.7,0.3,0
5.1,3.8,1.5,0.3,0
5.4,3.4,1.7,0.2,0
5.1,3.7,1.5,0.4,0
4.6,3.6,1.0,0.2,0
5.1,3.3,1.7,0.5,0
4.8,3.4,1.9,0.2,0
5.0,3.0,1.6,0.2,0
5.0,3.4,1.6,0.4,0
5.2,3.5,1.5,0.2,0
5.2,3.4,1.4,0.2,0
4.7,3.2,1.6,0.2,0
4.8,3.1,1.6,0.2,0
5.4,3.4,1.5,0.4,0
5.2,4.1,1.5,0.1,0
5.5,4.2,1.4,0.2,0
4.9,3.1,1.5,0.2,0
5.0,3.2,1.2,0.2,0
5.5,3.5,1.3,0.2,0
4.9,3.6,1.4,0.1,0
4.4,3.0,1.3,0.2,0
5.1,3.4,1.5,0.2,0
5.0,3.5,1.3,0.3,0
4.5,2.3,1.3,0.3,0
4.4,3.2,1.3,0.2,0
5.0,3.5,1.6,0.6,0
5.1,3.8,1.9,0.4,0
4.8,3.0,1.4,0.3,0
5.1,3.8,1.6,0.2,0
4.6,3.2,1.4,0.2,0
5.3,3.7,1.5,0.2,0
5.0,3.3,1.4,0.2,0
7.0,3.2,4.7,1.4,1
6.4,3.2,4.5,1.5,1
6.9,3.1,4.9,1.5,1
5.5,2.3,4.0,1.3,1
6.5,2.8,4.6,1.5,1
5.7,2.8,4.5,1.3,1
6.3,3.3,4.7,1.6,1
4.9,2.4,3.3,1.0,1
6.6,2.9,4.6,1.3,1
5.2,2.7,3.9,1.4,1
5.0,2.0,3.5,1.0,1
5.9,3.0,4.2,1.5,1
6.0,2.2,4.0,1.0,1
6.1,2.9,4.7,1.4,1
5.6,2.9,3.6,1.3,1
6.7,3.1,4.4,1.4,1
5.6,3.0,4.5,1.5,1
5.8,2.7,4.1,1.0,1
6.2,2.2,4.5,1.5,1
5.6,2.5,3.9,1.1,1
5.9,3.2,4.8,1.8,1
6.1,2.8,4.0,1.3,1
6.3,2.5,4.9,1.5,1
6.1,2.8,4.7,1.2,1
6.4,2.9,4.3,1.3,1
6.6,3.0,4.4,1.4,1
6.8,2.8,4.8,1.4,1
6.7,3.0,5.0,1.7,1
6.0,2.9,4.5,1.5,1
5.7,2.6,3.5,1.0,1
5.5,2.4,3.8,1.1,1
5.5,2.4,3.7,1.0,1
5.8,2.7,3.9,1.2,1
6.0,2.7,5.1,1.6,1
5.4,3.0,4.5,1.5,1
6.0,3.4,4.5,1.6,1
6.7,3.1,4.7,1.5,1
6.3,2.3,4.4,1.3,1
5.6,3.0,4.1,1.3,1
5.5,2.5,4.0,1.3,1
5.5,2.6,4.4,1.2,1
6.1,3.0,4.6,1.4,1
5.8,2.6,4.0,1.2,1
5.0,2.3,3.3,1.0,1
5.6,2.7,4.2,1.3,1
5.7,3.0,4.2,1.2,1
5.7,2.9,4.2,1.3,1
6.2,2.9,4.3,1.3,1
5.1,2.5,3.0,1.1,1
5.7,2.8,4.1,1.3,1
6.3,3.3,6.0,2.5,2
5.8,2.7,5.1,1.9,2
7.1,3.0,5.9,2.1,2
6.3,2.9,5.6,1.8,2
6.5,3.0,5.8,2.2,2
7.6,3.0,6.6,2.1,2
4.9,2.5,4.5,1.7,2
7.3,2.9,6.3,1.8,2
6.7,2.5,5.8,1.8,2
7.2,3.6,6.1,2.5,2
6.5,3.2,5.1,2.0,2
6.4,2.7,5.3,1.9,2
6.8,3.0,5.5,2.1,2
5.7,2.5,5.0,2.0,2
5.8,2.8,5.1,2.4,2
6.4,3.2,5.3,2.3,2
6.5,3.0,5.5,1.8,2
7.7,3.8,6.7,2.2,2
7.7,2.6,6.9,2.3,2
6.0,2.2,5.0,1.5,2
6.9,3.2,5.7,2.3,2
5.6,2.8,4.9,2.0,2
7.7,2.8,6.7,2.0,2
6.3,2.7,4.9,1.8,2
6.7,3.3,5.7,2.1,2
7.2,3.2,6.0,1.8,2
6.2,2.8,4.8,1.8,2
6.1,3.0,4.9,1.8,2
6.4,2.8,5.6,2.1,2
7.2,3.0,5.8,1.6,2
7.4,2.8,6.1,1.9,2
7.9,3.8,6.4,2.0,2
6.4,2.8,5.6,2.2,2
6.3,2.8,5.1,1.5,2
6.1,2.6,5.6,1.4,2
7.7,3.0,6.1,2.3,2
6.3,3.4,5.6,2.4,2
6.4,3.1,5.5,1.8,2
6.0,3.0,4.8,1.8,2
6.9,3.1,5.4,2.1,2
6.7,3.1,5.6,2.4,2
6.9,3.1,5.1,2.3,2
5.8,2.7,5.1,1.9,2
6.8,3.2,5.9,2.3,2
6.7,3.3,5.7,2.5,2
6.7,3.0,5.2,2.3,2
6.3,2.5,5.0,1.9,2
6.5,3.0,5.2,2.0,2
6.2,3.4,5.4,2.3,2
5.9,3.0,5.1,1.8,2

まとめ

 僕はデータサイエンティストではありません。ただのITエンジニアなので、「統計・分析」や「ビジネス・マーケティング」のスキルは初級レベルしか持ち合わせていません。
 ただ、この領域に手をつけるのに「ITスキル」+「解析スキル」だけじゃダメなんです。
 回帰分析は分析方法の一つでしかなく、それも正しい説明要素が選択できなければ、結果に信頼性はまるでありません。言い換えれば、分析以前に必要なデータを収集し、必要に応じた加工が必要なんです。卵が先か鶏が先か、みたいなもので、ご自身の専門はともあれ、そうした意識を持って取り組まないと、何をどう設計し、実装しようとも、有益な結果を得ることができないことを、意識してください。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?