001package org.opengion.penguin.math.statistics;
002
003/**
004 * 独自実装の二次回帰計算クラスです。
005 * f(x) = c1x^2 + c2x + c3
006 * の曲線を求めます。
007 */
008public class HybsSquadraticRegression implements HybsRegression{
009        private double c1;              // 2次の係数
010        private double c2;              // 1次の係数
011        private double c3;              // 0次の係数
012        private double rsquare;         // 決定係数 今のところ求めていない
013                
014        /**
015         * コンストラクタ。
016         * 与えた二次元データを元に二次回帰を計算します。
017         * @param data xとyの組み合わせの配列
018         */
019        public HybsSquadraticRegression(final double[][] data){
020                //二次回帰曲線を求めるが、これはapacheにはなさそうなので自前で計算する。
021                train(data);
022        }
023        
024        /**
025         * コンストラクタ。
026         * 係数を与えます。
027         * (以前に計算したものを利用)
028         * @param c1 2次の係数
029         * @param c2 1次の係数
030         * @param c3 0次の係数
031         * 
032         */
033        public HybsSquadraticRegression( final double c1, final double c2, final double c3 ){
034                this.c1 = c1;
035                this.c2 = c2;
036                this.c3 = c3;
037        }
038        
039        /**
040         * コンストラクタ
041         * このコンストラクタを利用した場合はtrainを実施して学習するか、setCoefficientで係数をセットする。
042         */
043        public HybsSquadraticRegression(){
044                //何もしない
045        }
046        
047        /**
048         * 係数計算
049         * 
050         * 
051         *      c3Σ+c2Σx+c1Σx^2=Σy
052         *      c3Σx+c2Σ(x^2)+c1Σx^3=Σ(xy)
053         *      c3Σ(x^2)+c2Σ(x^3)+c1Σ(x^4)=Σ(x^2*y)
054         *      この三元連立方程式を解くことになる。
055         *
056         * @param data x,yの配列
057         */
058        public void train( final double[][] data ){
059                // xの二乗等の総和用
060                int data_n=data.length;;
061                double data_x=0;
062                double data_y=0;
063                double x2=0;
064                double sumx2=0;
065                double sumx=0;
066                double sumxy=0;
067                double sumy=0;
068                double sumx3=0;
069                double sumx2y=0;
070                double sumx4=0;
071                
072                // まずは計算に使うための和を計算
073                for( int i=0; i < data_n; i++ ){
074                        data_x = data[i][0];
075                        data_y = data[i][1];
076                        x2 = data_x*data_x;
077                        
078                        sumx    += data_x;
079                        sumx2   += x2;
080                        sumxy   += data_x * data_y;
081                        sumy    += data_y;
082                        sumx3   += x2 * data_x;
083                        sumx2y  += x2 * data_y;
084                        sumx4   += x2 * x2;
085                }
086                
087                // ガウス・ジョルダン法で係数計算
088                double diffx2 = sumx2 - sumx * sumx / data_n;
089                double diffxy = sumxy - sumx * sumy / data_n;
090                double diffx3 = sumx3 - sumx2 * sumx /data_n;
091                double diffx2y = sumx2y - sumx2 * sumy /data_n;
092                double diffx4 = sumx4 - sumx2 * sumx2 /data_n;
093                double diffd = diffx2 * diffx4 - diffx3 * diffx3;
094                c1 = ( diffx2y * diffx2 - diffxy * diffx3 ) / diffd;
095                c2 = ( diffxy * diffx4 - diffx2y * diffx3 ) / diffd;
096                c3 = sumy/data_n - c2*sumx/ data_n - c1*sumx2/data_n;
097                
098        }
099        
100        /**
101         * このクラスでは未使用。
102         * 
103         * @param opt オプション
104         */
105        public void setOption(final double[] opt){
106                // 特にオプションなし
107        }
108        
109        /**
110         * 係数c1の取得。
111         * @return 係数c1
112         */
113        public double getC1(){
114                return c1;
115        }
116        
117        /**
118         * 係数c2の取得。
119         * @return 係数c2
120         */
121        public double getC2(){
122                return c2;
123        }
124        
125        /**
126         * 係数c3取得。
127         * @return 係数c3
128         */
129        public double getC3(){
130                return c3;
131        }
132        
133        /**
134         * c1,c2,c3の順にセットした配列を返します。
135         * @return 係数の配列
136         */
137        public double[] getCoefficient(){
138                double[] rtn = {c1,c2,c3};
139                return rtn;
140        }
141        
142        /**
143         * 決定係数の取得。
144         * @return 決定係数
145         */
146        public double getRSquare(){
147                return rsquare;
148        }
149        
150        /**
151         * c1,c2,c3の順に配列の内容をセットします。
152         * 
153         * @param in_c 係数配列
154         */
155        public void setCoefficient(final double[] in_c){
156                c1 = in_c[0];
157                c2 = in_c[1];
158                c3 = in_c[2];
159        }
160        
161        /**
162         * c1*x^2 + c2*x + c3を計算
163         * @param in_x 与えるx
164         * @return 計算結果
165         */
166        public double predict(final double... in_x){
167                return c1 * in_x[0] * in_x[0] + c2 * in_x[0] + c3;
168        }
169
170        /*** ここまでが本体 ***/
171        /*** ここからテスト用mainメソッド ***/
172        /**
173         * @param args *****************************************/
174        public static void main(final String [] args) {
175                double[][] data = {{1, 2.3}, {2, 5.1}, {3, 9.1}, {4, 16.2}}; 
176                
177                
178                HybsSquadraticRegression sr = new HybsSquadraticRegression(data);
179                System.out.println(sr.getC1());
180                System.out.println(sr.getC2());
181                System.out.println(sr.getC3());
182                
183                System.out.println(sr.predict( 5 ));
184        }
185}
186