hiramekunのブログ

プログラミングと読書と

python+scikit-learnで多項式曲線近似をリッジ回帰で求める

はじめに

実装はこちらに全て載っています。
7/5:リンク先変更しました。 github.com

多項式曲線近似とは

多項式曲線近似とは、関数があったとき、それを次のような多項式で近似することです。
{ \displaystyle
y(x, w) = w_0 + w_1x + w_2x^{2} + ... + w_Mx^{M} = \sum_{j=0}^{M} w_jx^{j}
}
ここで、Mはパラメータで、自分で近似する多項式の次数を決定します。一般的に次数が高すぎると過学習してしまい、学習データセットに対する精度は上がりますが、汎化性能については低くなってしまうことが多いです。

これを一般的に、線形モデルと言います。

係数の求め方

ここで、線形モデルの各係数wを求める必要があります。
多項式近似して予測した値と、実際の値との誤差関数を最小化するという方針で求めます。

一般的な回帰モデル

誤差関数は以下が用いられることが多いです。二乗誤差の最小値を求めに行くので、最小二乗法と呼ばれます。

{ \displaystyle
E(w) = \frac{1}{2} \sum_{n=1}^{N} \{y(x_n, w) - t_n\}^{2}
}
ここで、y(x_n, w)多項式によって予測された推定値。t_nは正解の値です。この誤差関数をw_j微分して最小値を求めます。

正則化した回帰モデル

一般的な回帰モデルだと、元の関数から大きくずれることがあります。それは、学習データセットに対して適合し過ぎてしまい、大きい係数が出てきてしまうためです。

そこでリッジ回帰モデルでは、罰則項を付加することにより過学習を防ぎます。この作業は正則化(regularization)と呼ばれています。

具体的には以下のような誤差関数によって回帰モデルを求めます。
最後に係数を二乗で足しているのがわかりますね。
{ \displaystyle
E(w) = \frac{1}{2} \sum_{n=1}^{N} \{y(x_n, w) - t_n\}^{2} + a|w|^{2}
}
ここで、aはパラメータで自分で値を設定します。

実装

元の関数

python+sklearnで実装しました。
元の関数、そしてそこから外れた値を生成するためにrandomな値を加算する関数を定義します。

def function(x):
    return 5 * np.sin(x * np.pi / 10)

def function_with_random(x):
    return function(x) + np.random.uniform(-5.0, 5.0)

多項式の定義

sklearnには、PolynomialFeaturesというクラスがあり、今回6次の多項式を定義しました。

from sklearn.preprocessing import PolynomialFeatures
polynomial_features = PolynomialFeatures(degree=6)
X_poly = polynomial_features.fit_transform(X)

このfit_transform()ですが、これは定義した多項式に対して、引数で与えられたXの値を代入した配列が入っています。今回は以下のようになります。

[[1, 0, 0, 0, 0, 0, 0],  # x = 0
[1, 1, 1, 1, 1, 1, 1, 1],  # x = 1
[1, 2, 4, 8, 16, 32, 64],  # x = 2
...
]

学習

次にやることは、先ほど各 1, x, x^{2}, x^{3}...の値を計算したので、それらの線形結合で予測値を出すということです。

from sklearn.linear_model import Ridge
poly_reg = Ridge(alpha=10, fit_intercept=False)
reg  = poly_reg.fit(X_poly, y_train)  # X_polyの線形結合によりyの値を学習する

予測

seabornを使ってグラフを描画しました。

y_ridge = poly_reg.predict(X_poly)  # 学習したモデルを用いて予測する。
sns.pointplot(x=X.reshape(len(X)), y= y_ridge, color='R', markers="")

赤がリッジ回帰によって得られた曲線。青が比較のために実装した正則化しない線形回帰によって得られた曲線です。
確かに外れ値の値を受けにくいような気がします。
f:id:thescript1210:20180409010132p:plain
モデルの平均二乗誤差は次のようになり、確かにリッジ回帰モデルでは過学習が抑制され、汎化性能を保っていることがわかります。

rigde_train_error: 3.8009207105204244
rigde_test_error: 2.665180445057409
linear_train_error: 1.987241639408757
linear_test_error: 3.6306109275834806