はじめに

GLM(一般化線形モデル)などstatsmodelsにしかない機能を使いたいことがありますが、scikit-learnのapiと切り替えて呼び出そうとすると、呼び出しの順序や引数を渡すタイミングが異なっていて不便です。

そこでstatsmodelsのapiをscikit learn風に呼び出せるwapperクラスを作りました。

wapperの定義

class SMWrapper():
    def __init__(self, base_cls, fit_intercept=True, **params):
        self.base_cls = base_cls
        self.fit_intercept = fit_intercept 
        self.params = params

    def fit(self, x, y, sample_weights=None):
        _x = sm.add_constant(x) if self.fit_intercept else x
        self.base_instance = self.base_cls(y, _x, freq_weights=sample_weights, **self.params)
        self.results = self.base_instance.fit()
        return self

    def score(self, x, y, sample_weights=None):
        self.fit(x, y, sample_weights)
        return self.base_instance.score(self.results.params)

    def predict(self, x):
        _x = sm.add_constant(x) if self.fit_intercept else x
        return self.base_instance.predict(self.results.params, _x)

使用法

import numpy as np
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm

x = np.arange(0, 20)
y = np.random.poisson(x+1, x.shape[0])

# regressor = LinearRegression()
regressor = SMWrapper(sm.GLM, family=sm.families.Poisson())
regressor.fit(x, y)
pred_y = regressor.predict(x)
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.