LoginSignup
1
0

More than 1 year has passed since last update.

Pyroで多項式線形回帰

Posted at

環境

  • Python 3.9.6
  • PyTorch 1.9.0
  • Pyro 1.7.0

目的

 PyTorchをベースに確率的生成モデルが扱えるPyroでLASSOやRidgeなどの多項式線形回帰を行います。これ自体ができても全く嬉しくないし役にも立ちませんが、多項式の係数を確率変数のベクトルとしてまとめて生成させる方法が見当たらなかったのでその練習として取り組みました。確率的生成モデルの正しい表記方法をいまだに理解していないので数式はほとんどありません。

モデル化

単回帰

 まず単回帰の場合をみます。詳細は先人達の記事を参照ください。12最小二乗法のモデルをPyroで表現するとこうなるはずです。(動作未確認)

OLS.py
import torch
import pyro
import pyro.distribution as dist
from pyro.distribution import constraints


def model(X, Y=None):
    #ただのパラメータはpyro.paramで定義する
    w0 = pyro.param('w0', torch.tensor(0.))
    w1 = pyro.param('w1', torch.tensor(0.))

    #分散が負だとエラーが出るのでconstraints.positiveで正数に限定
    #他によく見たのはconstraints.unit_intervalで[0, 1]の区間に限定
    sigma = pyro.param('sigma', torch.tensor(1.),
                       constraint=constraints.positive)

    #観測点Xが複数あってベクトルになっている場合はwith構文が必要
    #この'Y'をユーザー側でいつ使うかは不明
    with pyro.plate('Y'):
        y_ = w0 + w1*X

        #確率変数はpyro.sampleで確率分布から生成させる
        #この確率変数にアクセスするにはpyro上で'obs'を指定する
        y = pyro.sample('obs', dist.Normal(y_, sigma), obs=Y)

    return y_    

 これを最適化するとパラメータが決まります。最適化されたパラメータにはpyro.param('w0')などでアクセスできます。

 続いてRidge回帰です。Ridgeは係数の事前分布に平均0のガウス分布を採用し、事後確率を最大化することと同値です。すなわちMAP推定を行います

Ridge.py
import torch
import pyro
import pyro.distribution as dist


def model(X, Y=None):
    #分散をパラメータ化
    scale_0 = pyro.param('scale_0', torch.tensor(1.),
                         constraint=constraints.positive)
    #ガウス分布から確率変数を与える
    w0 = pyro.sample('w0', dist.Normal(0., scale_0))

    scale_1 = pyro.param('scale_1', torch.tensor(1.),
                         constraint=constraints.positive)
    w1= pyro.sample('w1', dist.Normal(0., scale_1))

    with pyro.plate('Y'):
        y_ = w0 + w1*X
        y = pyro.sample('obs', dist.Normal(y_, sigma), obs=Y)

    return y_


#変分ベイズ用の近似事後分布
#デルタ関数を使うと確率変数のMAP推定(点推定)にできる
def guide(X, Y=None):
    loc_0 = pyro.param('loc_0', torch.tensor(0.))

    #確率変数をモデル内と同じ名前にするとPyroが認識できる
    pyro.sample('w0', dist.Delta(loc_0))

    loc_1 = pyro.param('loc_1', torch.tensor(0.))
    pyro.sample('w1', dist.Delta(loc_1))

 これを推論機構に投げればPyroでRidge回帰ができて、最も過学習するようにハイパーパラメータも決まります。scikit-learnなどで交差検証で汎化性を持たせた場合と異なり予測器としてはゴミです。そもそも係数が疎になることが売りのRidgeを単回帰に適用することがナンセンス。あくまで例です。

多項式回帰

 ここから本題です。まず重回帰は変数ベクトル$\boldsymbol{x}^T = (1, x_1, x_2, \cdots ,x_n)^T$を用いれば定数項$w_0$も係数ベクトルにまとめて$y = \boldsymbol{x}^T \cdot \boldsymbol{w}$と書けてしまいます。さらに$y_i = \boldsymbol{x}^T_i \cdot \boldsymbol{w}$を全てまとめてしまえば計画行列$X$を用いて$\boldsymbol{y} = X\boldsymbol{w}$です。
 従ってPyroでも要素ごとに扱うなんて面倒なことをせずにベクトル$\boldsymbol{w}$のまま扱えた方が楽です。それには独立な確率変数をまとめて生成できるpyro.plateを使います。確率分布の引数にベクトルを指定している場合はpyro.plate(name)で十分ですが、スカラーを与えてベクトル化をした場合にはpyro.plate(name, dim)とベクトル化の次元数を指定する必要があります。LASSO回帰をモデル化してMAP推定するには次のようにします。

LASSO.py
from pyro.distribution import Laplace
def model(X, Y=None):
    #頭に1を加える
    X_ = torch.cat([torch.ones(X.shape[0], 1), X], dim=1)

    scale = pyro.param('scale', torch.tensor(1.),
                         constraint=constraints.positive)     
    noise = pyro.param('noise', torch.tensor(1.),
                         constraint=constraints.positive)

    #多項式の項数だけ確率変数Wを生じさせるためにX_.shape[1]を与える
    with pyro.plate('w', X_.shape[1]):
        W = pyro.sample('W', Laplace(0., scale))
        #W.shape = X_.shape[1]

    with pyro.plate('Y'):
        y_ = X_@W
        y = pyro.sample('obs', dist.Normal(y_, noise), obs=Y)

    return y_


def guide(X, Y=None):
    dims = X.shape[1]
    noise = pyro.param('_noise', torch.tensor(1.),
                        constraint=constraints.positive)
    pyro.sample('noise', dist.Delta(_noise))

    loc = pyro.param('loc', torch.tensor(1.).expand(dims+1))

    #確率分布の引数locが既にベクトルのためpyro.plateは名前だけ指定すればいい
    with pyro.plate('w'):
        pyro.sample('W', dist.Delta(loc))

 以上、分かればできるけど明記されてなくて苦労しました。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0