読者です 読者をやめる 読者になる 読者になる

グラムシュミッドの正規直交化を書いてみてRcppを比較する

f:id:aaaazzzz036:20140224180040p:plain
Eigen

あっ, どーも僕です.

そういえば, 理工系学者, 学生にとっては垂涎ものであるFreeFEMの日本語参考書がでましたね. FreeFEMの参考書は言語を問わず少ないので, 工学部系のほとんどの研究室で購入しているんじゃないでしょうか.

Rcpp

sourceCpp

訂正:cppFunction(誤)とsourceCpp(正)を間違えていました。。。

本題に入る前にすこしだけ.
RcppでC++関数を作る主な方法として, Rのオブジェクトを渡すcxxfunctionと, .cppを読み込んで作るsourceCppの2つがあります. どっちがいいかと思っていのですが, わたしはsourceCppを推します. なぜかというと, cxxfunctionのエラーメッセージはとてもわかりにくいです. 一方で, sourceCppnは, RStudio上で使うとちゃんと行を示してエラーの箇所を教えてくれます.

RcppEigen


前回, わたしはRcppArmadilloとRcppEigenのどちらが良いかわからないと書きましたが, グラム・シュミットの正規直交化法 - Wikipediaを計算するシミュレーションしたところRcppEigenの方が速かったです. RcppはRcppEigenを使いましょう.


シミュレーションはRcppEigen, RcppArmadillo, Rを用いて自分で書いたものと, Rのパッケージfarでおこないました. スクリプトは下の方に載せておきます. わたしが書いたアルゴリズムが最も速いかはわかりませんが, それぞれ同じアルゴリズムで書いていますので計算量にそこまで差はないかと思います.


下の図は, 1000次実対称行列を正規直交化する, という作業を10回おこなって得た平均所要時間です. Eigenが速いですね. パッケージfarの関数は圧倒的に遅かったので省いています(中身を見たら生Rのベタ書きでした).

f:id:aaaazzzz036:20140224170558j:plain



RcppEigenは速いですが, EigenのHPに書いてあるManualがすごく見づらいです.

スクリプト

全体的に名前があれですが, 察してください.

RcppEigen

.asDiagonal()の使い方に注意が必要かと思います.

#include <RcppEigen.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppEigen")]]
typedef Eigen::Map<Eigen::MatrixXd> MapMatd;

// [[Rcpp::export]]
Eigen::MatrixXd eigen_gramshimidt (MapMatd X, 
                                   bool normalize = true) {
    int n = X.rows(), k = X.cols();
    Eigen::MatrixXd X_orthgonalized(n, k), ProjMat, orthgonalized_elements;
    Eigen::VectorXd base(n), ProjCoef, pjc, pjk;
    X_orthgonalized.col(0) = X.col(0);
    for (int i = 1; i < k; i++) {
        base     = X.col(i);
        ProjMat  = X_orthgonalized.leftCols(i);
        // 訂正 2014 / 6 / 1
        // なぜか, 次のものがちゃんと動作しないのかわからない
        // おそらくは, ProjMat.transpose() * baseがMatrixXdで, 残りがVectorXdで, クラスが異なるため
        // ProjCoef = (ProjMat.transepose() * base).cwiseQuotient (ProjMat.colwise().squaredNorm())
        pjc      = ProjMat.transpose() * base;
        pjk      = ProjMat.colwise().squaredNorm();
       ProjCoef = pjc.cwiseQuotient(pjk);
        orthgonalized_elements = ProjMat * ProjCoef.asDiagonal(); 
        X_orthgonalized.col(i) = base - orthgonalized_elements.rowwise().sum();
    }
    
    if (normalize) {
        // .asDiagonalを使うときは, Vector.asDiagonal()と使う. 
        // .norm.asDiagonalとかはエラーがでないけど対角行列にならない
        Eigen::VectorXd norm_inv = X_orthgonalized.colwise().norm().cwiseInverse();
        X_orthgonalized *= norm_inv.asDiagonal();
    }
    
    
    return  X_orthgonalized;
}
RcppArmadillo
#include <RcppArmadillo.h>

// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::mat arma_gramshimidt (arma::mat X,
                            bool normalize = true) { 
    arma::mat X_orthgonalized, ProjMat, orthgonalized_elements;
    arma::colvec base, ProjCoef;
    int k = X.n_cols;
    X_orthgonalized.copy_size(X);
    X_orthgonalized.col(0) = X.col(0);
    for (int i = 1; i < k; i++) {
        base     = X.col(i);
        ProjMat  = X_orthgonalized.cols(0, i-1); // if i = 1, cols(0, i-1) == col(0)
        ProjCoef = (ProjMat.t() * base) / arma::sum ((ProjMat % ProjMat).t(), 1);
        orthgonalized_elements = ProjMat * arma::diagmat (ProjCoef); 
        X_orthgonalized.col(i) = base - arma::sum (orthgonalized_elements, 1);
    }
    
    if (normalize) {
        arma::colvec norms  = arma::sqrt(arma::sum(arma::pow(X_orthgonalized, 2), 0)).t();
        X_orthgonalized    *= arma::diagmat(arma::ones<arma::colvec>(k) / norms);
    }
    
    return X_orthgonalized;
}
R
myDotProduct <- function (A, B = NULL, doSQRT = FALSE) {
    if (!is.matrix (A) & !is.data.frame (A)) {
        warning ("A must be a matrix or data.frame")
    }
    if (is.null (B)) {
        norm2 <- c (colSums (A * A))   
    } else {
        if (!is.matrix (B) & !is.data.frame (B)) {
            warning ("B must be a matrix or data.frame")
        }
        norm2 <- c (colSums (A * B))
    }
   
    if (doSQRT) {
        norm2 <- sqrt (norm2)
    }
    
    return (norm2 = norm2)
}


myOrthogonalize_Schmidt <- function (A, norm = FALSE) {
    if (!is.matrix (A)) {
        warning ("A must be a  matrix")
    }
    if (ncol (A) == 1) {
        return (list (AO = A))
    }
    # 追加関数
    if (!exists ("myDotProduct")) {
        source ("myDotProduct.r")
    }
    
    
    # 直交化
    A_Orthgonalized <- A
    nc <- ncol (A)
    for (i in 2:nc) {
        base      <- A[, i]
        proj      <- A_Orthgonalized[, 1:(i-1), drop = FALSE]
        proj.coef <- - c (crossprod (proj, base)) / myDotProduct (proj)
        orthgonalize_element <- rowSums (sweep (proj, 2, proj.coef, "*"))
        A_Orthgonalized[, i] <- A_Orthgonalized[, i] + orthgonalize_element
    }
    
    
    # 結果を正規化する場合
    if (norm) {
        A_Orthgonalized <- sweep (A_Orthgonalized, 2, myDotProduct (A_Orthgonalized, doSQRT = TRUE), "/")
    } 
    
    
    return (list (AO = A_Orthgonalized))
    
}

シミュレーション
library (Rcpp)
library (RcppArmadillo)
library (RcppEigen)
sourceCpp ("Arma_GramShimidt_orthogonalization.cpp")
sourceCpp ("Eigen_GramShimidt_orthogonalization.cpp")
# 動作確認
library (far) # for orthonomalization
A <- cov (iris[ , -5])
t1 <- orthonormalization (A, norm = TRUE)
t2 <- myOrthogonalize_Schmidt (A, norm = TRUE)$AO
t3 <- arma_gramshimidt(A)
t4 <- eigen_gramshimidt(A)
print (crossprod (t1))
print (crossprod (t2))
print (crossprod (t3))
print (crossprod (t4))

# 速度比較
library(rbenchmark)
N <- 1000
A <- matrix (rnorm (N * N), N, N); A <- tcrossprod (A)
benchmark (orthonormalization (A, norm = TRUE), 
           myOrthogonalize_Schmidt (A, norm = TRUE), 
           arma_gramshimidt(A), 
           eigen_gramshimidt(A), 
           replications = 10, 
           order="relative")
#                                      test replications elapsed relative user.self sys.self
# 4                    eigen_gramshimidt(A)           10   61.12    1.000     46.99    14.00
# 3                     arma_gramshimidt(A)           10  142.15    2.326    120.76    21.03
# 2 myOrthogonalize_Schmidt(A, norm = TRUE)           10  328.41    5.373    286.48    41.04
# 1      orthonormalization(A, norm = TRUE)           10 1855.85   30.364   1810.68    42.21

おわりに

先日のTokyo.RではRcppの話題がでてきませんでしたね. 残念.