ひらめの日常

日常のメモをつらつらと

Rでガウス最尤推定(フィッシャーの線形判別分析)

全ての実装はこちらにあります。
7/5:リンク先更新しました。

github.com

ガウス最尤推定

流れ

ガウスモデルに標本が従っていると仮定します。  

そのガウスモデルから計算した対数事後確率が最も高いクラスに、標本を振り分けると考えます。
そのために、ガウスモデルのパラメータ推定→事後確率計算という流れを経ることになります。

定義

ガウスモデルとは正規分布のことです。
こちらの記事でも紹介していますが、定義は以下のようになっています。

hiramekun.hatenablog.com

{ \displaystyle
f(x) = \frac{1}{(2\pi)^{\frac{D}{2}}|\Sigma|^{\frac{1}{2}}}\exp(-\frac{1}{2}(x-\mu)^{T}\Sigma^{-1}(x-\mu))
}

{D}:次元
{x}:D次元ベクトル
{\Sigma}:xの分散共分散行列
{\mu}:D次元の分散ベクトル

最尤推定

最尤推定自体に関する説明は省きます。
各カテゴリごとに異なるガウスモデルを仮定した場合(つまり、 {y_i} ごとに異なる平均や分散の値を持っているということです)、上記のガウスモデルの最尤推定量は次のように計算されます。
 {\displaystyle
\hat{\mu_y} = \frac{1}{n_y}\sum_{i:y_i = y}x_i
}

 {\displaystyle
\hat{\Sigma_y} = \frac{1}{n_y}\sum_{i:y_i = y}(x_i - \hat{\mu_y})(x_i - \hat{\mu_y})^{T}
}

{\sum_{i:y_i = y}}{y_i = y} を満たす  {i} に関する和

対数事後確率

ベイズの定理と事後確率自体に関する説明は省きます。
対数事後確率は、計算すると以下のようになります。
 {\displaystyle
\log{\hat{p}(y|x)} = -\frac{1}{2}(x-\hat{\mu_y})^{T}{\hat{\Sigma_y}}^{-1}(x-\hat{\mu_y}) - \frac{1}{2}\log{|\hat{\Sigma_y}|} + \log{n_y} + C
}

やや複雑な形になるので、今回は各カテゴリの分散共分散行列が等しい時を考えます(つまり、 {y_i} の分散は等しく、同じ散らばりを持った分布であるとします)。

その時、対数事後確率は以下のように簡略化できます。
 {\displaystyle
\log{\hat{p}(y|x)} = \hat{\mu_y}^{T}{\hat{\Sigma_y}}^{-1}x - \frac{1}{2}\hat{\mu_y}^{T}{\hat{\Sigma_y}}^{-1}\hat{\mu_y} + \log{n_y} + C^{'}
}

フィッシャーの線形判別分析

この、分散共分散行列が共通の時、決定境界は超平面になります。
この場合をフィッシャーの線形判別分析と呼びます。

Rで手書き文字を分類する

事後確率計算

trainデータの分散共分散行列と平均を受け取って、testデータに対する事後確率を計算します。

calc_possibility = function(mean, invs, test_label) {
  mean %*% invs %*% test_label - (mean %*% invs %*% t(mean) / 2)[1][1]
}

結果

2クラス分類

まずは1と2を分類するタスクを実装しました。

正確には、2のtestデータセットを与え、1と判別するか2と判別するかを測定しました。
すると、testデータセット200に対して精度が99%と出ました。

[1] 0.99

そこで間違えて分類している2枚の画像はこんな感じ。
f:id:thescript1210:20180506051123p:plain

多クラス分類

今度は0~9までの数字全てに対して一度に分類を行ってみました。

それぞれの数字に対する精度を出力しました。

[1] "Accuracy of 1: 0.995000"
[1] "Accuracy of 2: 0.845000"
[1] "Accuracy of 3: 0.905000"
[1] "Accuracy of 4: 0.910000"
[1] "Accuracy of 5: 0.815000"
[1] "Accuracy of 6: 0.925000"
[1] "Accuracy of 7: 0.905000"
[1] "Accuracy of 8: 0.825000"
[1] "Accuracy of 9: 0.910000"
[1] "Accuracy of 0: 0.960000"

平均は約89%で、もちろんdeepには及ばないものの、簡単な実装でそこそこの精度が出ることがわかりました。

[1] 0.8995