Updated Intra2.java to allow for changing weights, MAX_CG ceiling
authorNathan E. Egge <negge@dgql.org>
Mon, 1 Oct 2012 22:23:03 +0000 (18:23 -0400)
committerNathan E. Egge <negge@dgql.org>
Mon, 1 Oct 2012 22:23:03 +0000 (18:23 -0400)
tools/java/src/intra/Intra2.java

index 20282f7..c0a812e 100644 (file)
@@ -27,8 +27,16 @@ public class Intra2 {
 
        public static final int STEPS=30;
 
+       public static final int DC_WEIGHT=10;
+
+       public static final int BITS_PER_COEFF=6;
+
+       public static final double MAX_CG=BITS_PER_COEFF*-10*Math.log10(0.5);
+
        public static final boolean USE_CG=true;
 
+       public static final boolean UPDATE_WEIGHT=false;
+
        public static final int[] MODE_COLORS={
                0xFF000000,
                0xFFFFFFFF,
@@ -135,15 +143,15 @@ public class Intra2 {
                protected double cgPerCoeff(double[] _mse) {
                        double total_cg=0;
                        for (int j=0;j<B_SZ*B_SZ;j++) {
-                               total_cg+=-10*Math.log10(_mse[j]/(covariance[3*B_SZ*B_SZ+j][3*B_SZ*B_SZ+j]/weightTotal));
+                               double cg=-10*Math.log10(_mse[j]/(covariance[3*B_SZ*B_SZ+j][3*B_SZ*B_SZ+j]/weightTotal));
+                               if (cg>=MAX_CG) {
+                                       cg=MAX_CG;
+                               }
+                               total_cg+=cg;
                        }
                        return(total_cg/(B_SZ*B_SZ));
                }
 
-               protected void printStats() {
-                       System.out.println("  "+numBlocks+"\t"+msePerCoeff(mse)+"\t"+cgPerCoeff(mse));
-               }
-
                protected double predError(int[] _data,int _y) {
                        double pred=beta_0[_y];
                        for (int i=0;i<3*B_SZ*B_SZ;i++) {
@@ -152,7 +160,7 @@ public class Intra2 {
                        return(Math.abs(_data[INDEX[3*B_SZ*B_SZ+_y]]-pred));
                }
 
-               protected void mseUpdateHelper(int[] _data,int _weight,double[] _mse) {
+               protected void mseUpdateHelper(int[] _data,double _weight,double[] _mse) {
                        for (int j=0;j<B_SZ*B_SZ;j++) {
                                double e=predError(_data,j);
                                double se=e*e;
@@ -162,18 +170,18 @@ public class Intra2 {
                        }
                }
 
-               protected void mseUpdate(int[] _data,int _weight) {
+               protected void mseUpdate(int[] _data,double _weight) {
                        mseUpdateHelper(_data,_weight,mse);
                }
 
-               protected double deltaBits(int[] _data,int _weight) {
+               protected double deltaBits(int[] _data,double _weight) {
                        double old_cg=cgPerCoeff(mse);
                        double[] mse=new double[B_SZ*B_SZ];
                        mseUpdateHelper(_data,_weight,mse);
                        update(_weight,_data);
                        double cg=cgPerCoeff(mse);
                        update(-_weight,_data);
-                       return(cg-old_cg);
+                       return((cg-old_cg)*(weightTotal+_weight));
                }
 
        };
@@ -244,13 +252,13 @@ public class Intra2 {
                                        int weight=dis.readShort();
                                        //System.out.println("mode="+mode+" weight="+weight);
                                        dos.writeShort(mode);
-                                       //dos.writeShort(weight);
+                                       dos.writeDouble(weight);
                                        int[] data=new int[2*B_SZ*2*B_SZ];
                                        for (int i=0;i<2*B_SZ*2*B_SZ;i++) {
                                                data[i]=dis.readShort();
                                        }
                                        if (weight>0) {
-                                               modeData[mode].addBlock(mode==0?1:weight,data);
+                                               modeData[mode].addBlock(mode==0?DC_WEIGHT:weight,data);
                                                //modeData[mode].addBlock(weight,data);
                                        }
                                        rgb[block]=MODE_COLORS[mode];
@@ -267,6 +275,39 @@ public class Intra2 {
                return(modeData);
        }
 
+       protected void fitData(ModeData[] _modeData) {
+               // compute betas and MSE
+               for (int i=0;i<MODES;i++) {
+                       /*System.out.println("mode "+i);
+                       double[] old_mse=new double[B_SZ*B_SZ];
+                       for (int j=0;j<B_SZ*B_SZ;j++) {
+                               old_mse[j]=_modeData[i].mse[j];
+                       }*/
+                       _modeData[i].computeBetas();
+                       /*for (int j=0;j<B_SZ*B_SZ;j++) {
+                               System.out.println("  "+j+": "+old_mse[j]+"\t"+_modeData[i].mse[j]+"\t"+(_modeData[i].mse[j]-old_mse[j]));
+                       }*/
+               }
+       }
+
+       protected static final int SPACE=4;
+
+       protected void printStats(ModeData[] _modeData) {
+               double mse_sum=0;
+               double cg_sum=0;
+               double weight=0;
+               for (int i=0;i<MODES;i++) {
+                       double mse=_modeData[i].msePerCoeff(_modeData[i].mse);
+                       double cg=_modeData[i].cgPerCoeff(_modeData[i].mse);
+                       System.out.println("  "+i+": "+_modeData[i].numBlocks+"\t"+_modeData[i].weightTotal+"\t"+mse+"\t"+cg);
+                       mse_sum+=_modeData[i].weightTotal*mse;
+                       cg_sum+=_modeData[i].weightTotal*cg;
+                       weight+=_modeData[i].weightTotal;
+               }
+               System.out.println("Average MSE "+mse_sum/weight);
+               System.out.println("Average CG  "+cg_sum/weight);
+       }
+
        protected void processData(int step,ModeData[] _modeData,File[] _files) throws IOException {
                for (File file : _files) {
                        System.out.println("Processing "+file.getPath());
@@ -280,14 +321,19 @@ public class Intra2 {
 
                        int[] rgb=new int[nx*ny];
                        for (int block=0;block<nx*ny;block++) {
+                               if ((block&0x3fff)==0) {
+                                       System.out.println(block);
+                               }
+
                                // load the data
                                int mode=dis.readShort();
-                               int weight=dis.readShort();
+                               double weight=dis.readShort();
                                int[] data=new int[2*B_SZ*2*B_SZ];
                                for (int i=0;i<2*B_SZ*2*B_SZ;i++) {
                                        data[i]=dis.readShort();
                                }
                                int lastMode=dis2.readShort();
+                               double lastWeight=dis2.readDouble();
                                if (weight>0) {
                                        // compute error
                                        double[] error=new double[MODES];
@@ -297,7 +343,7 @@ public class Intra2 {
                                                        error[i]+=_modeData[i].predError(data,j);
                                                }
                                                if (USE_CG) {
-                                                       cg[i]=_modeData[i].deltaBits(data,(lastMode==i?-1:1)*(i==0?1:weight));
+                                                       cg[i]=_modeData[i].deltaBits(data,(lastMode==i?-1:1)*(i==0?DC_WEIGHT:weight));
                                                }
                                        }
                                        double best=-Double.MAX_VALUE;
@@ -320,14 +366,20 @@ public class Intra2 {
                                                        }
                                                }
                                        }
-                                       if (mode!=lastMode) {
-                                               _modeData[lastMode].mseUpdate(data,lastMode==0?-1:-weight);
-                                               _modeData[mode].mseUpdate(data,mode==0?1:weight);
-                                               _modeData[lastMode].removeBlock(lastMode==0?1:weight,data);
-                                               _modeData[mode].addBlock(mode==0?1:weight,data);
+                                       if (UPDATE_WEIGHT) {
+                                               weight=best-nextBest;
+                                       }
+                                       if (USE_CG) {
+                                               _modeData[lastMode].mseUpdate(data,lastMode==0?-DC_WEIGHT:-lastWeight);
+                                               _modeData[mode].mseUpdate(data,mode==0?DC_WEIGHT:weight);
                                        }
+                                       _modeData[lastMode].removeBlock(lastMode==0?DC_WEIGHT:lastWeight,data);
+                                       _modeData[lastMode].computeBetas();
+                                       _modeData[mode].addBlock(mode==0?DC_WEIGHT:weight,data);
+                                       _modeData[mode].computeBetas();
                                }
                                dos.writeShort(mode);
+                               dos.writeDouble(weight);
                                rgb[block]=MODE_COLORS[mode];
                        }
 
@@ -344,7 +396,9 @@ public class Intra2 {
        public void run() throws Exception {
                File[] files=getFiles();
                if (files==null) {
-                       System.out.println("no data files in "+DATA_FOLDER);
+                       System.out.println("No .coeffs files found in "+DATA_FOLDER+" folder.  Enable the PRINT_BLOCKS ifdef and run:");
+                       System.out.println("  for f in subset1-y4m/*.y4m; do ./init_intra_maps $f && ./init_intra_xform $f 2> $f.coeffs; done");
+                       System.out.println("Move the *.coeffs files to "+DATA_FOLDER);
                        return;
                }
 
@@ -353,11 +407,11 @@ public class Intra2 {
 
                long start=System.currentTimeMillis();
                for (int k=1;k<=STEPS;k++) {
-                       // compute betas and MSE
-                       for (int i=0;i<MODES;i++) {
-                               modeData[i].computeBetas();
-                               modeData[i].printStats();
-                       }
+
+                       // update model
+                       fitData(modeData);
+
+                       printStats(modeData);
 
                        // reclassify blocks
                        processData(k,modeData,files);