001package org.opengion.penguin.math.statistics;
002
003import java.util.Arrays;
004
005/**
006 * 独自実装の二次回帰計算クラスです。
007 * f(x) = c1x^2 + c2x + c3
008 * の曲線を求めます。
009 */
010public class HybsSquadraticRegression implements HybsRegression {
011        private final double[] cnst = new double[3] ;           // 係数(0次、1次、2次)
012        private double rsquare;         // 決定係数 今のところ求めていない
013
014        /**
015         * コンストラクタ。
016         * 与えた二次元データを元に二次回帰を計算します。
017         *
018         * @param data xとyの組み合わせの配列
019         */
020        public HybsSquadraticRegression( final double[][] data ) {
021                //二次回帰曲線を求めるが、これはapacheにはなさそうなので自前で計算する。
022                train( data );
023        }
024
025        /**
026         * 係数計算
027         * 
028         *      c3Σ+c2Σx+c1Σx^2=Σy
029         *      c3Σx+c2Σ(x^2)+c1Σx^3=Σ(xy)
030         *      c3Σ(x^2)+c2Σ(x^3)+c1Σ(x^4)=Σ(x^2*y)
031         *      この三元連立方程式を解くことになる。
032         *
033         * @param data x,yの配列
034         */
035        private void train( final double[][] data ) {
036                // xの二乗等の総和用
037                final int data_n=data.length;;
038                double sumx2    = 0;
039                double sumx             = 0;
040                double sumxy    = 0;
041                double sumy             = 0;
042                double sumx3    = 0;
043                double sumx2y   = 0;
044                double sumx4    = 0;
045
046                // まずは計算に使うための和を計算
047                for( int i=0; i < data_n; i++ ) {
048                        final double data_x     = data[i][0];
049                        final double data_y     = data[i][1];
050                        final double x2         = data_x*data_x;
051
052                        sumx    += data_x;
053                        sumx2   += x2;
054                        sumxy   += data_x * data_y;
055                        sumy    += data_y;
056                        sumx3   += x2 * data_x;
057                        sumx2y  += x2 * data_y;
058                        sumx4   += x2 * x2;
059                }
060
061                // ガウス・ジョルダン法で係数計算
062                final double diffx2 = sumx2 - sumx * sumx / data_n;
063                final double diffxy = sumxy - sumx * sumy / data_n;
064                final double diffx3 = sumx3 - sumx2 * sumx /data_n;
065                final double diffx2y = sumx2y - sumx2 * sumy /data_n;
066                final double diffx4 = sumx4 - sumx2 * sumx2 /data_n;
067                final double diffd = diffx2 * diffx4 - diffx3 * diffx3;
068
069                cnst[2] = ( diffx2y * diffx2 - diffxy * diffx3 ) / diffd;
070                cnst[1] = ( diffxy * diffx4 - diffx2y * diffx3 ) / diffd;
071                cnst[0] = sumy/data_n - cnst[1]*sumx/ data_n - cnst[2]*sumx2/data_n;
072
073                rsquare = 0;            // 決定係数 今のところ求めていない
074        }
075
076        /**
077         * 決定係数の取得。
078         * @return 決定係数
079         */
080        @Override
081        public double getRSquare() {
082                return rsquare;
083        }
084
085        /**
086         * 係数(0次、1次、2次)の順にセットした配列を返します。
087         *
088         * @return 係数の配列
089         */
090        @Override
091        public double[] getCoefficient() {
092                return Arrays.copyOf( cnst,cnst.length );
093        }
094
095        /**
096         * c2*x^2 + c1*x + c0を計算。
097         *
098         * @param in_x 必要な大きさの変数配列
099         * @return 計算結果
100         */
101        @Override
102        public double predict( final double... in_x ) {
103                return cnst[2] * in_x[0] * in_x[0] + cnst[1] * in_x[0] + cnst[0];
104        }
105
106        //************** ここまでが本体 **************
107        /**
108         * ここからテスト用mainメソッド 。
109         *
110         * @param args 引数
111         */
112        public static void main( final String[] args ) {
113                final double[][] data = {{1, 2.3}, {2, 5.1}, {3, 9.1}, {4, 16.2}}; 
114
115                final HybsSquadraticRegression sr = new HybsSquadraticRegression(data);
116
117                final double[] cnst = sr.getCoefficient();
118
119                System.out.println(cnst[2]);
120                System.out.println(cnst[1]);
121                System.out.println(cnst[0]);
122
123                System.out.println(sr.predict( 5 ));
124        }
125}
126