001package org.opengion.penguin.math.statistics; 002 003import org.apache.commons.math3.stat.StatUtils; 004import org.apache.commons.math3.linear.RealMatrix; 005import org.apache.commons.math3.linear.Array2DRowRealMatrix; 006import org.apache.commons.math3.linear.LUDecomposition; 007import org.apache.commons.math3.stat.correlation.Covariance; 008 009/** 010 * apache.commons.mathを利用した、マハラノビス距離関係の処理クラスです。 011 * 012 * 相関を考慮した距離が求まります。 013 * 教師無し学習的に、異常値(外れ値)検知に利用可能です。 014 * 閾値は95%区間の2.448がデフォルトです。(3なら99%) 015 * 016 * 「Juan Francisco Quesada-Brizuela」氏の距離計算PGを参照しています。 017 * 学術的には様々な改良が提案されていますが、このクラスでは単純なマハラノビス距離を扱います。 018 */ 019public class HybsMahalanobis { 020 021 private RealMatrix dataMatrix; //与えたデータ 022 private double[] dataDistance; // 元データの各マハラノビス距離 023 private double[] average; // 平均 024 private RealMatrix covariance; //共分散 025 private double limen=2.448; // 異常値検知をする際の閾値(初期値は95%信頼楕円) 026 027 /** 028 * コンストラクタ。 029 * 与えたデータマトリクスを元にマハラノビス距離を求めるための準備をします。 030 * (平均と共分散を求めます) 031 * 引数calcにtrueをセットすると各点のマハラノビス距離を計算します。 032 * 033 * データ = { { 90 ,60 }, { 70, 80 } } 034 * のような形としてデータを与えます。 035 * 036 * @param matrix 値のデータ 037 * @param calc 距離計算を行うかどうか 038 */ 039 public HybsMahalanobis(final double[][] matrix, final boolean calc){ 040 // 一応元データをセットしておく 041 this.dataMatrix = new Array2DRowRealMatrix(matrix); 042 043 // 共分散行列を作成 044 covariance = new Covariance(matrix).getCovarianceMatrix(); 045 //平均の配列を作成 046 average = new double[matrix[0].length]; 047 for( int i=0; i<matrix[0].length; i++){ 048 average[i] = StatUtils.mean(dataMatrix.getColumn(i)); 049 } 050 051 if(calc){ 052 dataDistance = new double[matrix.length]; 053 for( int i=0; i< matrix.length; i++ ){ 054 // dataDistance[i] = distance( matrix[i] ); 055 dataDistance[i] = distance( covariance,matrix[i],average ); // PMD:Overridable method 'distance' called during object construction 056 } 057 // 標準偏差、平均を取る場合 058 //double maxDst = StatUtils.max( dataDistance ); 059 //double vrDst = StatUtils.variance( dataDistance ); 060 //double shigma = Math.sqrt(vrDst); 061 //double meanDst = StatUtils.mean( dataDistance ); 062 } 063 } 064 065 /** 066 * 距離計算がtrueの形の簡易版コンストラクタです。 067 * 068 * @param matrix 値データ 069 */ 070 public HybsMahalanobis(final double[][] matrix){ 071 this(matrix,true); 072 } 073 074 /** 075 * コンストラクタ。 076 * 計算済みの共分散と平均、閾値を与えるパターン。 077 * 078 * @param covarianceData 共分散 079 * @param averageData 平均配列 080 */ 081 public HybsMahalanobis(final double[][] covarianceData, final double[] averageData){ 082 this.covariance = new Array2DRowRealMatrix(covarianceData); 083 this.average = averageData; 084 } 085 086 087 /** 088 * 平均配列を返します。 089 * 090 * @return 平均 091 */ 092 public double[] getAverage(){ 093 return average; 094 } 095 096 /** 097 * 共分散配列を返します。 098 * 099 * @return 共分散 100 */ 101 public double[][] getCovariance(){ 102 return covariance.getData(); 103 } 104 105 /** 106 * 閾値を返します。 107 * 108 * @return 閾値 109 */ 110 public double getLimen(){ 111 return limen; 112 } 113 114 /** 115 * 平均配列をセットします。 116 * 117 * @param ave 平均 118 */ 119 public void setAverage( final double[] ave ){ 120 this.average = ave; 121 } 122 123 /** 124 * 共分散配列をセットします。 125 * 126 * @param cvr 共分散 127 */ 128 public void setCovariance( final double[][] cvr ){ 129 this.covariance = new Array2DRowRealMatrix(cvr); 130 } 131 132 /** 133 * 閾値をセットします。 134 * 距離の二乗がカイ2乗分布となるため、 135 * 初期値は2.448で、95%区間を意味します。 136 * 2が86%、3が99%です。 137 * 138 * @param lim 閾値 139 */ 140 public void setLimen( final double lim ){ 141 this.limen = lim; 142 } 143 144 /** 145 * コンストラクタで元データを与え、計算させた場合のマハラノビス距離の配列を返します。 146 * 147 * @return 各点のマハラノビス距離の配列 148 */ 149 public double[] getDataDistance(){ 150 return dataDistance; 151 } 152 153 /** 154 * マハラノビス距離を計算します。 155 * 156 * @param vec 判定する点(ベクトル) 157 * @return マハラノビス距離 158 */ 159 public double distance( final double[] vec){ 160 return distance( covariance, vec, average ); 161 } 162 163 /** 164 * 与えたベクトルが閾値を超えたマハラノビス距離かどうかを判定します。 165 * 閾値以下ならtrue、超えている場合はfalseを返します。 166 * (異常値判定) 167 * 168 * @param vec 判定する点(ベクトル) 169 * @return 閾値以下かどうか 170 */ 171 public boolean check( final double[] vec){ 172 final double dist = distance( covariance, vec, average ); 173 return ( dist <= limen ); 174 } 175 176 /** 177 * 平均、共分散を利用して対象ベクトルとの距離を測ります。 178 * 179 * @param mtx1 共分散行列 180 * @param vec1 距離を測りたいベクトル 181 * @param vec2 平均ベクトル 182 * @return マハラノビス距離 183 */ 184 private double distance(final RealMatrix mtx1, final double vec1[], final double vec2[]) { 185 // マハラノビス距離の公式 186 // マハラノビス距離 = (v1-v2)*inv(m1)*t(v1-v2) 187 // inv():逆行列 188 // t():転置行列 189 190 // ※getDeterminantは行列式(正方行列に対して定義される量)を取得 191 // javaの処理上、v1.lengthが2以上の場合、1/(v1.length)が0になる。 192 // その結果、行列式を0乗になるので、detに1が設定される。 193 // この式はマハラノビス距離を求める公式にない為、不要な処理? 194 final double det = Math.pow((new LUDecomposition(mtx1).getDeterminant()), 1/(vec1.length)); 195 196 double[] tempSub = new double[vec1.length]; 197 198 // (x - y)を計算 199 for(int i=0; i < vec1.length; i++){ 200 tempSub[i] = (vec1[i]-vec2[i]); 201 } 202 203 double[] temp = new double[vec1.length]; 204 205 // (x - y) * det 不要な処理? 206 for(int i=0; i < temp.length; i++){ 207 temp[i] = tempSub[i]*det; 208 } 209 210 // m2: (x - y)を行列に変換 211 final RealMatrix m2 = new Array2DRowRealMatrix(new double[][] { temp }); 212 213 // m3: m2 * 共分散行列の逆行列 214 final RealMatrix m3 = m2.multiply(new LUDecomposition(mtx1).getSolver().getInverse()); 215 216 // m4: m3 * (x-y)の転置行列 217 final RealMatrix m4 = m3.multiply((new Array2DRowRealMatrix(new double[][] { temp })).transpose()); 218 219 // m4の平方根を返す 220 return Math.sqrt(m4.getEntry(0, 0)); 221 } 222 223 // *** ここまでが本体 *** 224 225 /** 226 * ここからテスト用mainメソッド。 227 * 228 * @param args **************************************** 229 */ 230 public static void main( final String [] args ) { 231 // 幾何的には、これらの重心を中心とした楕円の中に入っているかどうかを判定 232 final double[][] data = { 233 {2, 10}, 234 {4, 21}, 235 {6, 27}, 236 {8, 41}, 237 {10, 50} 238 }; 239 240 final double[] test = {12, 50}; 241 final double[] test2 = {12, 59}; 242 243 final HybsMahalanobis rtn = new HybsMahalanobis(data); 244 245 System.out.println( java.util.Arrays.toString(rtn.getDataDistance()) ); 246 247 System.out.println(rtn.check( test )); 248 System.out.println(rtn.check( test2 )); 249 } 250} 251