

Calling the statsmodels API through a scikit-learn-like wrapper

Posted at 2018-03-29

Introduction

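Sometimes you want a feature that only statsmodels provides, such as GLM (generalized linear models). Switching between it and the scikit-learn API is awkward, though, because the two differ in call order and in when arguments are passed: statsmodels takes the data in the model constructor (`endog` first, then `exog`) and calls `fit()` with no data, while scikit-learn takes hyperparameters in the constructor and the data in `fit(X, y)`.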

So I wrote a wrapper class that lets you call statsmodels models in scikit-learn style.

Wrapper definition

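import statsmodels.api as sm
from sklearn.metrics import r2_score


class SMWrapper:
    """Wrap a statsmodels model class so it can be called like a scikit-learn regressor."""

    def __init__(self, base_cls, fit_intercept=True, **params):
        self.base_cls = base_cls            # e.g. sm.GLM
        self.fit_intercept = fit_intercept  # prepend a constant column, like sklearn
        self.params = params                # extra model kwargs, e.g. family=...

    def _design(self, x):
        # has_constant='add' forces the constant column even if x already looks constant
        return sm.add_constant(x, has_constant='add') if self.fit_intercept else x

    def fit(self, x, y, sample_weight=None):
        kwargs = dict(self.params)
        if sample_weight is not None:
            # freq_weights is a GLM argument; only pass it when weights are given
            kwargs['freq_weights'] = sample_weight
        self.base_instance = self.base_cls(y, self._design(x), **kwargs)
        self.results = self.base_instance.fit()
        return self

    def score(self, x, y, sample_weight=None):
        # Like sklearn regressors: R^2 of the already fitted model on (x, y), no refit
        return r2_score(y, self.predict(x), sample_weight=sample_weight)

    def predict(self, x):
        return self.results.predict(self._design(x))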

Usage

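import numpy as np
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm

# sklearn estimators expect a 2-D feature matrix, so reshape to (n_samples, 1)
x = np.arange(0, 20).reshape(-1, 1)
y = np.random.poisson(x.ravel() + 1)

# Swap the two lines below to switch between scikit-learn and statsmodels (SMWrapper as defined above)
# regressor = LinearRegression()
regressor = SMWrapper(sm.GLM, family=sm.families.Poisson())
regressor.fit(x, y)
pred_y = regressor.predict(x)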