はじめに
GLM(一般化線形モデル)などstatsmodelsにしかない機能を使いたいことがありますが、scikit-learnのapiと切り替えて呼び出そうとすると、呼び出しの順序や引数を渡すタイミングが異なっていて不便です。
そこでstatsmodelsのapiをscikit learn風に呼び出せるwapperクラスを作りました。
wapperの定義
class SMWrapper():
def __init__(self, base_cls, fit_intercept=True, **params):
self.base_cls = base_cls
self.fit_intercept = fit_intercept
self.params = params
def fit(self, x, y, sample_weights=None):
_x = sm.add_constant(x) if self.fit_intercept else x
self.base_instance = self.base_cls(y, _x, freq_weights=sample_weights, **self.params)
self.results = self.base_instance.fit()
return self
def score(self, x, y, sample_weights=None):
self.fit(x, y, sample_weights)
return self.base_instance.score(self.results.params)
def predict(self, x):
_x = sm.add_constant(x) if self.fit_intercept else x
return self.base_instance.predict(self.results.params, _x)
使用法
import numpy as np
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
x = np.arange(0, 20)
y = np.random.poisson(x+1, x.shape[0])
# regressor = LinearRegression()
regressor = SMWrapper(sm.GLM, family=sm.families.Poisson())
regressor.fit(x, y)
pred_y = regressor.predict(x)