001/* 002 * Copyright (c) 2009 The openGion Project. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 013 * either express or implied. See the License for the specific language 014 * governing permissions and limitations under the License. 015 */ 016package org.opengion.penguin.math.statistics; 017 018import java.util.Arrays; 019 020import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; 021 022/** 023 * apache.commons.mathを利用したOLS重回帰計算のクラスです。 024 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。 025 * c0の切片を考慮するかどうかはnoInterceptで決めます。 026 * 027 */ 028public class HybsMultiRegression implements HybsRegression { 029 private double cnst[]; // 各係数(xの種類+1になる?) 030 private double rsquare; // 決定係数 031 private boolean noIntercept; //切片を利用するかどうか 032 033 /** 034 * コンストラクタ。 035 * 与えた二次元データを元に重回帰を計算します。 036 * xデータとして二次元配列を与えます。 037 * noInterceptで切片有り無しを選択します。 038 * 039 * @param in_x 説明変数 040 * @param in_y 目的変数 041 * @param noIntercept 切片利用有無(trueで利用しない) 042 */ 043 public HybsMultiRegression( final double[][] in_x, final double[] in_y, final boolean noIntercept ) { 044 train( in_x, in_y, noIntercept ); 045 } 046 047 /** 048 * 与えた二次元データを元に重回帰を計算します。 049 * xデータとして二次元配列を与えます。 050 * noInterceptで切片有り無しを選択します。 051 * 052 * @param in_x 説明変数 053 * @param in_y 目的変数 054 * @param noIntercept 切片利用有無(trueで利用しない) 055 */ 056 private void train( final double[][] in_x, final double[] in_y, final boolean noIntercept ) { 057 this.noIntercept = noIntercept; 058 059 // ここで重回帰計算 060 final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); 061 regression.setNoIntercept(noIntercept); 062 regression.newSampleData(in_y, in_x); 063 064 cnst = regression.estimateRegressionParameters(); 065 rsquare = regression.calculateRSquared(); 066 } 067 068 /** 069 * 係数をセットした配列を返します。 070 * 071 * @return 係数の配列 072 */ 073 @Override 074 public double[] getCoefficient() { 075 return Arrays.copyOf( cnst,cnst.length ); 076 } 077 078 /** 079 * 決定係数の取得。 080 * @return 決定係数 081 */ 082 @Override 083 public double getRSquare() { 084 return rsquare; 085 } 086 087 /** 088 * 計算( c0 + c1x1...)を行う。 089 * noInterceptによってc0の利用を決める。 090 * xの大きさが足りない場合は0を返す。 091 * 092 * @param in_x 必要な大きさの変数配列 093 * @return 計算結果 094 */ 095 @Override 096 public double predict( final double... in_x ) { 097 double rtn = 0; 098 final int itr = noIntercept ? 0 : 1; 099 if( in_x.length < cnst.length-itr ) { 100 return rtn; 101 } 102 103 for( int i=0; i < in_x.length; i++ ) { 104 rtn = rtn + in_x[i] * cnst[i+itr]; 105 } 106 if( !noIntercept ) { rtn = rtn + cnst[0]; } 107 108 return rtn; 109 } 110 111 // ================ ここまでが本体 ================ 112 113 /** 114 * ここからテスト用mainメソッド 。 115 * 116 * @param args 引数 117 */ 118 public static void main( final String[] args ) { 119 // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより 120 final double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 }; 121 double[][] x = new double[10][]; 122 x[0] = new double[] { 165, 65 }; 123 x[1] = new double[] { 170, 68 }; 124 x[2] = new double[] { 172, 70 }; 125 x[3] = new double[] { 175, 65 }; 126 x[4] = new double[] { 170, 80 }; 127 x[5] = new double[] { 172, 85 }; 128 x[6] = new double[] { 183, 78 }; 129 x[7] = new double[] { 187, 79 }; 130 x[8] = new double[] { 180, 95 }; 131 x[9] = new double[] { 185, 97 }; 132 133 final HybsMultiRegression mr = new HybsMultiRegression(x,y,true); 134 135 System.out.println( mr.getRSquare() ); 136 System.out.println( Arrays.toString( mr.getCoefficient()) ); 137 138 System.out.println( mr.predict( new double[] { 169,85 } )); 139 } 140} 141