001package org.opengion.penguin.math.statistics; 002 003import java.util.Arrays; 004 005import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; 006 007/** 008 * apache.commons.mathを利用したOLS重回帰計算のクラスです。 009 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。 010 * c0の切片を考慮するかどうかはnoInterceptで決めます。 011 * 012 */ 013public class HybsMultiRegression implements HybsRegression { 014 private double cnst[]; // 各係数(xの種類+1になる?) 015 private double rsquare; // 決定係数 016 private boolean noIntercept; //切片を利用するかどうか 017 018 /** 019 * コンストラクタ。 020 * 与えた二次元データを元に重回帰を計算します。 021 * xデータとして二次元配列を与えます。 022 * noInterceptで切片有り無しを選択します。 023 * 024 * @param in_x 説明変数 025 * @param in_y 目的変数 026 * @param noIntercept 切片利用有無(trueで利用しない) 027 */ 028 public HybsMultiRegression( final double[][] in_x, final double[] in_y, final boolean noIntercept ) { 029 train( in_x, in_y, noIntercept ); 030 } 031 032 /** 033 * 与えた二次元データを元に重回帰を計算します。 034 * xデータとして二次元配列を与えます。 035 * noInterceptで切片有り無しを選択します。 036 * 037 * @param in_x 説明変数 038 * @param in_y 目的変数 039 * @param noIntercept 切片利用有無(trueで利用しない) 040 */ 041 private void train( final double[][] in_x, final double[] in_y, final boolean noIntercept ) { 042 this.noIntercept = noIntercept; 043 044 // ここで重回帰計算 045 final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); 046 regression.setNoIntercept(noIntercept); 047 regression.newSampleData(in_y, in_x); 048 049 cnst = regression.estimateRegressionParameters(); 050 rsquare = regression.calculateRSquared(); 051 } 052 053 /** 054 * 係数をセットした配列を返します。 055 * 056 * @return 係数の配列 057 */ 058 @Override 059 public double[] getCoefficient() { 060 return Arrays.copyOf( cnst,cnst.length ); 061 } 062 063 /** 064 * 決定係数の取得。 065 * @return 決定係数 066 */ 067 @Override 068 public double getRSquare() { 069 return rsquare; 070 } 071 072 /** 073 * 計算( c0 + c1x1...)を行う。 074 * noInterceptによってc0の利用を決める。 075 * xの大きさが足りない場合は0を返す。 076 * 077 * @param in_x 必要な大きさの変数配列 078 * @return 計算結果 079 */ 080 @Override 081 public double predict( final double... in_x ) { 082 double rtn = 0; 083 final int itr = noIntercept ? 0 : 1; 084 if( in_x.length < cnst.length-itr ) { 085 return rtn; 086 } 087 088 for( int i=0; i < in_x.length; i++ ) { 089 rtn = rtn + in_x[i] * cnst[i+itr]; 090 } 091 if( !noIntercept ) { rtn = rtn + cnst[0]; } 092 093 return rtn; 094 } 095 096 //************** ここまでが本体 ************** 097 /** 098 * ここからテスト用mainメソッド 。 099 * 100 * @param args 引数 101 */ 102 public static void main( final String[] args ) { 103 // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより 104 final double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 }; 105 double[][] x = new double[10][]; 106 x[0] = new double[] { 165, 65 }; 107 x[1] = new double[] { 170, 68 }; 108 x[2] = new double[] { 172, 70 }; 109 x[3] = new double[] { 175, 65 }; 110 x[4] = new double[] { 170, 80 }; 111 x[5] = new double[] { 172, 85 }; 112 x[6] = new double[] { 183, 78 }; 113 x[7] = new double[] { 187, 79 }; 114 x[8] = new double[] { 180, 95 }; 115 x[9] = new double[] { 185, 97 }; 116 117 final HybsMultiRegression mr = new HybsMultiRegression(x,y,true); 118 119 System.out.println( mr.getRSquare() ); 120 System.out.println( Arrays.toString( mr.getCoefficient()) ); 121 122 System.out.println( mr.predict( new double[] { 169,85 } )); 123 } 124} 125