Getting rid of the DCT in od_compute_dist_8x8()
authorJean-Marc Valin <jmvalin@jmvalin.ca>
Wed, 18 Jan 2017 06:58:33 +0000 (01:58 -0500)
committerJean-Marc Valin <jmvalin@jmvalin.ca>
Wed, 18 Jan 2017 20:54:34 +0000 (15:54 -0500)
Replacing the DCT and frequency weighting by a filter

   PSNR | PSNR Cb | PSNR Cr | PSNR HVS |   SSIM | MS SSIM | CIEDE 2000
-0.0995 | -0.9381 | -0.9595 |  -1.1745 | 0.5285 |  0.1386 |    -0.1813

src/encode.c

index 42e5b7c..25f4c43 100644 (file)
@@ -1100,10 +1100,15 @@ static int od_compute_var_4x4(od_coeff *x, int stride) {
   return (s2 - (sum*sum >> 4));
 }
 
+/* OD_DIST_LP_MID controls the frequency weighting filter used for computing
+   the distortion. For a value X, the filter is [1 X 1]/(X + 2) and
+   is applied both horizontally and vertically. For X=5, the filter is
+   a good approximation for the OD_QM8_Q4_HVS quantization matrix. */
+#define OD_DIST_LP_MID (5)
+#define OD_DIST_LP_NORM (OD_DIST_LP_MID + 2)
+
 static double od_compute_dist_8x8(daala_enc_ctx *enc, od_coeff *x, od_coeff *y,
- int stride, int bs) {
-  od_coeff e[8*8];
-  od_coeff et[8*8];
+ od_coeff *e_lp, int stride) {
   double sum;
   int min_var;
   double mean_var;
@@ -1148,29 +1153,21 @@ static double od_compute_dist_8x8(daala_enc_ctx *enc, od_coeff *x, od_coeff *y,
 #else
   activity = 1;
 #endif
-  for (i = 0; i < 8; i++) {
-    for (j = 0; j < 8; j++) e[8*i + j] = x[i*stride + j] - y[i*stride + j];
-  }
-  (*enc->state.opt_vtbl.fdct_2d[OD_BLOCK_8X8])(&et[0], 8, &e[0], 8);
   sum = 0;
   for (i = 0; i < 8; i++) {
     for (j = 0; j < 8; j++) {
-      double mag;
-      mag = 16./OD_QM8_Q4_HVS[i*8 + j];
-      /* We attempt to consider the basis magnitudes here, though that's not
-         perfect for block size 16x16 and above since only some edges are
-         filtered then. */
-      mag *= OD_BASIS_MAG[0][bs][i << (bs - 1)]*
-       OD_BASIS_MAG[0][bs][j << (bs - 1)];
-      mag *= mag;
-      sum += et[8*i + j]*(double)et[8*i + j]*mag;
+      sum += e_lp[i*stride + j]*(double)e_lp[i*stride + j];
     }
   }
+  /* Normalize the filter to unit DC response and add rough compensation for
+     basis magnitude. */
+  sum *= 0.92/(OD_DIST_LP_NORM*OD_DIST_LP_NORM*OD_DIST_LP_NORM*OD_DIST_LP_NORM);
+  /*printf("%f %f\n", sum, sum2);*/
   return activity*activity*(sum + vardist);
 }
 
 static double od_compute_dist(daala_enc_ctx *enc, od_coeff *x, od_coeff *y,
- int n, int bs) {
+ int n) {
   int i;
   double sum;
   sum = 0;
@@ -1182,10 +1179,37 @@ static double od_compute_dist(daala_enc_ctx *enc, od_coeff *x, od_coeff *y,
     }
   }
   else {
+    int j;
+    od_coeff e[OD_BSIZE_MAX*OD_BSIZE_MAX];
+    od_coeff tmp[OD_BSIZE_MAX*OD_BSIZE_MAX];
+    od_coeff e_lp[OD_BSIZE_MAX*OD_BSIZE_MAX];
+    int mid = OD_DIST_LP_MID;
+    for (i = 0; i < n; i++) {
+      for (j = 0; j < n; j++) {
+        e[i*n + j] = x[i*n + j] - y[i*n + j];
+      }
+    }
+    for (i = 0; i < n; i++) {
+      tmp[i*n] = mid*e[i*n] + 2*e[i*n + 1];
+      tmp[i*n + n - 1] = mid*e[i*n + n - 1] + 2*e[i*n + n - 2];
+      for (j = 1; j < n - 1; j++) {
+        tmp[i*n + j] = mid*e[i*n + j] + e[i*n + j - 1] + e[i*n + j + 1];
+      }
+    }
+    for (j = 0; j < n; j++) {
+      e_lp[j] = mid*tmp[j] + 2*tmp[n + j];
+      e_lp[(n - 1)*n + j] = mid*tmp[(n - 1)*n + j] + 2*tmp[(n - 2)*n + j];
+    }
+    for (i = 1; i < n - 1; i++) {
+      for (j = 0; j < n; j++) {
+        e_lp[i*n + j] = mid*tmp[i*n + j] + tmp[(i - 1)*n + j]
+         + tmp[(i + 1)*n + j];
+      }
+    }
     for (i = 0; i < n; i += 8) {
-      int j;
       for (j = 0; j < n; j += 8) {
-        sum += od_compute_dist_8x8(enc, &x[i*n + j], &y[i*n + j], n, bs);
+        sum += od_compute_dist_8x8(enc, &x[i*n + j], &y[i*n + j],
+         &e_lp[i*n + j], n);
       }
     }
     /* Compensate for the fact that the quantization matrix lowers the
@@ -1390,10 +1414,10 @@ static int od_block_encode(daala_enc_ctx *enc, od_mb_enc_ctx *ctx, int bs,
     for (i = 0; i < n; i++) {
       for (j = 0; j < n; j++) c_noskip[n*i + j] = c[bo + i*w + j];
     }
-    dist_noskip = od_compute_dist(enc, c_orig, c_noskip, n, bs);
+    dist_noskip = od_compute_dist(enc, c_orig, c_noskip, n);
     lambda = enc->bs_rdo_lambda;
     rate_noskip = od_ec_enc_tell_frac(&enc->ec) - tell;
-    dist_skip = od_compute_dist(enc, c_orig, mc_orig, n, bs);
+    dist_skip = od_compute_dist(enc, c_orig, mc_orig, n);
     rate_skip = (1 << OD_BITRES)*od_encode_cdf_cost(0,
      enc->state.adapt.skip_cdf[2*bs + (pli != 0)],
      4 + (pli == 0 && bs > 0));
@@ -1768,8 +1792,8 @@ static int od_encode_recursive(daala_enc_ctx *enc, od_mb_enc_ctx *ctx,
         for (j = 0; j < n; j++) split[n*i + j] = ctx->c[bo + i*w + j];
       }
       rate_split = od_ec_enc_tell_frac(&enc->ec) - tell;
-      dist_split = od_compute_dist(enc, c_orig, split, n, bs);
-      dist_nosplit = od_compute_dist(enc, c_orig, nosplit, n, bs);
+      dist_split = od_compute_dist(enc, c_orig, split, n);
+      dist_nosplit = od_compute_dist(enc, c_orig, nosplit, n);
       lambda = enc->bs_rdo_lambda;
       if (skip_split || dist_nosplit + lambda*rate_nosplit < dist_split
        + lambda*rate_split) {
@@ -2750,7 +2774,7 @@ static void od_encode_coefficients(daala_enc_ctx *enc, od_mb_enc_ctx *mbctx,
               out[y*n + x] = output[y*w + x];
             }
           }
-          dist = od_compute_dist(enc, orig, out, n, 3);
+          dist = od_compute_dist(enc, orig, out, n);
           best_dist = dist + enc->dering_lambda*
            od_encode_cdf_cost(0, state->adapt.dering_cdf[c], OD_DERING_LEVELS);
           for (gi = 1; gi < OD_DERING_LEVELS; gi++) {
@@ -2769,7 +2793,7 @@ static void od_encode_coefficients(daala_enc_ctx *enc, od_mb_enc_ctx *mbctx,
                   buf32[y*n + x] = buf[y*n + x];
                 }
               }
-              dist = od_compute_dist(enc, orig, buf32, n, 3)
+              dist = od_compute_dist(enc, orig, buf32, n)
                + enc->dering_lambda*od_encode_cdf_cost(gi,
                state->adapt.dering_cdf[c], OD_DERING_LEVELS);
             }