Rでk-means法とその拡張3 改良k-means編

あっ, どーも僕です。


Rでk-means法とその拡張の最後は, 改良k-meansです。

概要は前々回の記事をみてください。また, アルゴリズムは論文をみてみてください。フリーで入手できます。

さっそく, 結果ですが, クラスター数を3にしたときがこちら。相変わらずirisを使います。色がクラスタリングの結果で, 記号はSpeciesで分けています。うまくクラスタリングされたことがわかります。

改良k-meansはどうやらRで実装されていなさそうなので正しく実装できているかは, 論文と結果を比較してみました。比較したところ, 問題なしです。論文では, 3つのデータがクラスタリング失敗していますが, グラフをみると, オレンジの三角が2つと緑の十字がひとつあることがわかります。

k-meansを多変量正規分布で修正するだけで, こんなにもすごい結果がでるのですね。すごい。

f:id:aaaazzzz036:20131126134904p:plain


で, 問題なしといったのですが, ほんとは問題ありです。論文で書いてある通りにAICを計算したつもりなのですが, 論文では40000くらいで, このスクリプトは400くらいなので最後の対数尤度の計算が間違ってそうです.....

もし, なにが間違ってるかお分かりの方がいましたらコメント欄にでもいいので教えて下さい。。。。

注意!!Rのコードを書き直しました。s-kmeans.rは前回のエントリーに載せてあります!!@2014/09/13

# ***********************************************************
# mahalanobis-kmeans.r 
# ***********************************************************
mkmeans <- setRefClass (
	Class = "mkmeans",

	contains = c ("s_kmeans"), 

	fields  = list (NUM_CLUSTER = "integer",
			cur_cluster = "integer", 
			df          = "integer",
			cur_center  = "list",
			cur_covar   = "list",
			AIC         = "numeric",
			BIC         = "numeric",
			logLL       = "numeric",
			dat         = "data.frame") ,

	methods = list (
		initialize = function ()
		{
			.self$addPackages ()
		},
		mahalanobis_kmeans = function (dat_, num_cluster_)
		{
			dat         <<- .self$cleanData (dat_)
			NUM_CLUSTER <<- as.integer (num_cluster_)
			cur_cluster <<- simple_kmeans (dat, NUM_CLUSTER)
			.self$assignDf ()
			.self$updateCenters ()
			.self$updateCovars ()
			convergence <- FALSE
			iter        <- 0
			old_cluster <- cur_cluster
			while (!convergence) {
				iter <- iter + 1
				if (iter > 100)
					break
				logLikelihood_eachCluster <-
				    sapply (seq_len (NUM_CLUSTER),
					    function (k) 
					    {
					     center_ <- cur_center[[k]]
					     covar_  <- cur_covar[[k]]
					     dmvnorm (dat, center_, covar_, log = TRUE)
					  })
				cur_cluster <<- as.integer (max.col (logLikelihood_eachCluster))
				if (length (unique (cur_cluster)) != NUM_CLUSTER)
					cur_cluster <<- simple_kmeans (dat, NUM_CLUSTER)
				convergence <-  identical (old_cluster, cur_cluster)
				old_cluster <-  cur_cluster
				.self$updateCenters ()
				.self$updateCovars ()
			}
			.self$m_logLikelihood ()
			.self$m_AIC ()
			.self$m_BIC ()
		},
		addPackages = function (CRAN = "http://cran.ism.ac.jp/", 
					more = NULL)
		{
			addPackages <- c ("dplyr", "matrixStats", "mvtnorm", more)
			toInstPacks <- setdiff (addPackages, .packages(all.available = TRUE))
			for (pkg in toInstPacks)
			    install.packages (pkg, repos = CRAN)
			for (pkg in addPackages)
			    suppressPackageStartupMessages (library (pkg, character.only = TRUE))
		},
		assignDf = function ()
		{
			K   <- NUM_CLUSTER
			nc  <- ncol (dat)
			df  <<- as.integer ((nc * K) + (K * nc * (nc + 1) * .5) + (K-1))
		},
		updateCenters = function ()
		{
			cbind (dat, cc = cur_cluster) %>%
			    group_by (cc) %>%
			    do (rtn = colMeans (dplyr::select (., -cc))) %>%
			    as.list () %>%
			    "$"(rtn) ->>
			    cur_center
		},
		updateCovars = function ()
		{
			cbind (dat, cc = cur_cluster) %>%
			    group_by (cc) %>%
			    do (rtn = cov (dplyr::select (., -cc))) %>%
			    as.list () %>%
			    "$"(rtn) ->>
			    cur_covar
		},
		cleanData = function (data_)
		{
			out            <- na.omit (data_ [, sapply (data_, is.numeric)])
			rownames (out) <- 1:nrow (out)
			return (out)
		},
		m_logLikelihood = function ()
		{
			logMixingRate <- log (table (cur_cluster)) - log (length (cur_cluster))
			loglikelihood <- sapply (seq_len (NUM_CLUSTER),
						 function (k) 
                                                 {
						  center_ <- cur_center[[k]]
						  covar_  <- cur_covar [[k]]
						  dmvnorm (dat, center_, covar_, log = TRUE) + logMixingRate[k]
						  })
			logLL <<- sum (rowLogSumExps (loglikelihood))
		},
		m_AIC = function ()
		{
			AIC <<- - 2 * logLL + 2 * df
		},
		m_BIC = function ()
		{
			BIC <<- - 2 * logLL + df * log (nrow (dat))
		},
		showIC = function ()
		{
			cat ("NUM_CLUSTER = ", NUM_CLUSTER, "\n")
			cat ("AIC = ", AIC, "\n")
			cat ("BIC = ", BIC, "\n")
		}
	)

)
追記 2013/11/28

最終的に, 4つk-meansをまとめて比較するとこんな感じになりました。各円には, それぞれ二万個のデータ, つまり合計(2万×2)×3要素のデータをクラスタリングしてみました。

何度か試してみるとk-means, fuzzy c-meansはいまk-meanで示しているものか, fuzzy c-meansで示しているもののどちらかのクラスタリングをします。k-means mahalanobis(改良k-means)が一番いいクラスタリングをしており, 円ごとにクラスタリングされているのがわかります。ただ, ときたまこれとは違ったクラスタリングをしてしまいます。x-meansはなんどやっても同じ結果となりました。k-means・c-meansよりも若干いい感じにクラスタリングしているかと思います。
f:id:aaaazzzz036:20131128233300p:plain