001package org.opengion.penguin.math.statistics; 002 003import java.util.Arrays; 004 005import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; 006 007/** 008 * apache.commons.mathを利用したOLS重回帰計算のクラスです。 009 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。 010 * c0の切片を考慮するかどうかはnoInterceptで決めます。 011 * 012 */ 013//public class HybsMultiRegression { 014public class HybsMultiRegression implements HybsRegression { 015 private double cnst[]; // 各係数(xの種類+1になる?) 016 private double rsquare; // 決定係数 017 private boolean noIntercept; //切片を利用するかどうか 018 019 /** 020 * コンストラクタ。 021 * 与えた二次元データを元に重回帰を計算します。 022 * xデータとして二次元配列を与えます。 023 * noInterceptで切片有り無しを選択します。 024 * @param in_x 説明変数 025 * @param in_y 目的変数 026 * @param noIntercept 切片利用有無(trueで利用しない) 027 */ 028 public HybsMultiRegression(final double[][] in_x, final double[] in_y, final boolean noIntercept){ 029 train( in_x, in_y, noIntercept ); 030 031// this.noIntercept = noIntercept; 032// 033// // ここで重回帰計算 034// OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); 035// regression.setNoIntercept(noIntercept); 036// regression.newSampleData(in_y, in_x); 037// 038// cnst = regression.estimateRegressionParameters(); 039// rsquare = regression.calculateRSquared(); 040 } 041 042 /** 043 * コンストラクタ。 044 * 係数配列を与えられるようにしておきます。 045 * (以前に計算したものを利用) 046 * @param in_c 係数配列 047 * @param noIntercept 切片利用有無(trueで利用しない) 048 * 049 */ 050 public HybsMultiRegression( final double[] in_c, final boolean noIntercept){ 051 this.cnst = in_c; 052 this.noIntercept = noIntercept; 053 } 054 055 /** 056 * 与えた二次元データを元に重回帰を計算します。 057 * xデータとして二次元配列を与えます。 058 * noInterceptで切片有り無しを選択します。 059 * 060 * @param in_x 説明変数 061 * @param in_y 目的変数 062 * @param noIntercept 切片利用有無(trueで利用しない) 063 */ 064 private void train( final double[][] in_x, final double[] in_y, final boolean noIntercept ) { 065 this.noIntercept = noIntercept; 066 067 // ここで重回帰計算 068 final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); 069 regression.setNoIntercept(noIntercept); 070 regression.newSampleData(in_y, in_x); 071 072 cnst = regression.estimateRegressionParameters(); 073 rsquare = regression.calculateRSquared(); 074 } 075 076// /** 077// * 係数の取得。 078// * @return 係数配列 079// */ 080// public double[] getParam(){ 081// return cnst; 082// } 083 084 /** 085 * 係数をセットした配列を返します。 086 * 087 * @return 各係数の配列 088 */ 089 @Override 090 public double[] getCoefficient() { 091 return Arrays.copyOf( cnst,cnst.length ); 092 } 093 094 /** 095 * 配列の内容を係数としてセットします。 096 * 097 * @param in_c 係数配列 098 */ 099 public void setCoefficient(final double[] in_c){ 100 cnst = in_c; 101 } 102 103 /** 104 * 決定係数の取得。 105 * @return 決定係数 106 */ 107 public double getRSquare(){ 108 return rsquare; 109 } 110 111 /** 112 * 計算( c0 + c1x1...)を行う。 113 * noInterceptによってc0の利用を決める。 114 * xの大きさが足りない場合は0を返す。 115 * 116 * @param in_x 必要な大きさの変数配列 117 * @return 計算結果 118 */ 119 public double predict(final double... in_x){ 120 double rtn = 0; 121 int itr = noIntercept ? 0 : 1; 122 if( in_x.length < cnst.length-itr ){ 123 return 0; 124 } 125 126 for( int i=0; i < in_x.length; i++ ){ 127 rtn = rtn + in_x[i] * cnst[i+itr]; 128 } 129 if( !noIntercept ){ rtn = rtn + cnst[0]; } 130 131 return rtn; 132 } 133 134 /*** ここまでが本体 ***/ 135 /*** ここからテスト用mainメソッド ***/ 136 /** 137 * @param args *****************************************/ 138 public static void main(final String [] args) { 139 // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより 140 double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 }; 141 double[][] x = new double[10][]; 142 x[0] = new double[] { 165, 65 }; 143 x[1] = new double[] { 170, 68 }; 144 x[2] = new double[] { 172, 70 }; 145 x[3] = new double[] { 175, 65 }; 146 x[4] = new double[] { 170, 80 }; 147 x[5] = new double[] { 172, 85 }; 148 x[6] = new double[] { 183, 78 }; 149 x[7] = new double[] { 187, 79 }; 150 x[8] = new double[] { 180, 95 }; 151 x[9] = new double[] { 185, 97 }; 152 153 154 HybsMultiRegression mr = new HybsMultiRegression(x,y,true); 155 156 System.out.println( mr.getRSquare() ); 157 System.out.println( Arrays.toString( mr.getCoefficient()) ); 158 159 System.out.println( mr.predict( new double[] { 169,85 } )); 160 } 161} 162