optimisation: The "simple" Rxy/sqrt(Ryy) case in alg_quant no longer requires
authorJean-Marc Valin <Jean-Marc.Valin@csiro.au>
Tue, 25 Mar 2008 23:34:23 +0000 (10:34 +1100)
committerJean-Marc Valin <Jean-Marc.Valin@csiro.au>
Tue, 25 Mar 2008 23:34:23 +0000 (10:34 +1100)
a division

libcelt/vq.c
libcelt/vq.h

index 3a6494e..cb16c69 100644 (file)
@@ -1,4 +1,4 @@
-/* (C) 2007 Jean-Marc Valin, CSIRO
+/* (C) 2007-2008 Jean-Marc Valin, CSIRO
 */
 /*
    Redistribution and use in source and binary forms, with or without
@@ -39,9 +39,8 @@
 #include "arch.h"
 #include "os_support.h"
 
-/** Takes the pitch vector and the decoded residual vector (non-compressed), 
-   applies the compression in the pitch direction, computes the gain that will
-   give ||p+g*y||=1 and mixes the residual with the pitch. */
+/** Takes the pitch vector and the decoded residual vector, computes the gain
+    that will give ||p+g*y||=1 and mixes the residual with the pitch. */
 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
 {
    int i;
@@ -99,7 +98,6 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
    VARDECL(celt_norm_t, y);
    VARDECL(int, iy);
    VARDECL(int, signx);
-   VARDECL(celt_word32_t, scores);
    int i, j, is;
    celt_word16_t s;
    int pulsesLeft;
@@ -118,7 +116,6 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
    ALLOC(y, N, celt_norm_t);
    ALLOC(iy, N, int);
    ALLOC(signx, N, int);
-   ALLOC(scores, N, celt_word32_t);
 
    for (j=0;j<N;j++)
    {
@@ -150,29 +147,49 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
       int sign;
       celt_word32_t Rxy, Ryy, Ryp;
       celt_word32_t g;
+      celt_word32_t best_num;
+      celt_word16_t best_den;
+      int best_id;
       
       /* Decide on how many pulses to find at once */
       pulsesAtOnce = pulsesLeft/N;
       if (pulsesAtOnce<1)
          pulsesAtOnce = 1;
 
+      /* This should ensure that anything we can process will have a better score */
+      best_num = -SHR32(VERY_LARGE32,4);
+      best_den = 0;
+      best_id = 0;
       /* Choose between fast and accurate strategy depending on where we are in the search */
       if (pulsesLeft>1)
       {
          for (j=0;j<N;j++)
          {
+            celt_word32_t num;
+            celt_word16_t den;
             /* Select sign based on X[j] alone */
             sign = signx[j];
             s = SHL16(sign*pulsesAtOnce, yshift);
             /* Temporary sums of the new pulse(s) */
             Rxy = xy + MULT16_16(s,X[j]);
             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
-            /* This score is approximate, but good enough for the first pulses */
-            scores[j] = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
+            
+            /* Approximate score: we maximise Rxy/sqrt(Ryy) */
+            num = MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14)));
+            den = ROUND16(Ryy,14);
+            /* The idea is to check for num/den >= best_num/best_den, but that way
+               we can do it without any division */
+            if (MULT16_32_Q15(best_den, num) >= MULT16_32_Q15(den, best_num))
+            {
+               best_den = den;
+               best_num = num;
+               best_id = j;
+            }
          }
       } else {
          for (j=0;j<N;j++)
          {
+            celt_word32_t num;
             /* Select sign based on X[j] alone */
             sign = signx[j];
             s = SHL16(sign*pulsesAtOnce, yshift);
@@ -190,12 +207,17 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
             /* Knowing that gain, what's the error: (x-g*y)^2 
                (result is negated and we discard x^2 because it's constant) */
             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
-            scores[j] = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
-                    - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
+            num = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
+                  - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
+            if (num >= best_num)
+            {
+               best_num = num;
+               best_id = j;
+            } 
          }
       }
       
-      j = find_max32(scores, N);
+      j = best_id;
       is = signx[j]*pulsesAtOnce;
       s = SHL16(is, yshift);
 
index f5c507d..8e50886 100644 (file)
@@ -1,4 +1,4 @@
-/* (C) 2007 Jean-Marc Valin, CSIRO
+/* (C) 2007-2008 Jean-Marc Valin, CSIRO
 */
 /**
    @file vq.h
@@ -48,7 +48,6 @@
  * @param N Number of samples to encode
  * @param K Number of pulses to use
  * @param p Pitch vector (it is assumed that p+x is a unit vector)
- * @param alpha compression factor to apply in the pitch direction (magic!)
  * @param enc Entropy encoder state
 */
 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc);
@@ -58,7 +57,6 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
  * @param N Number of samples to decode
  * @param K Number of pulses to use
  * @param p Pitch vector (automatically added to x)
- * @param alpha compression factor in the pitch direction (magic!)
  * @param dec Entropy decoder state
  */
 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec);