optimisations: Another bunch of simplifications to alg_quant(), mainly to
authorJean-Marc Valin <jean-marc.valin@usherbrooke.ca>
Tue, 25 Mar 2008 10:28:40 +0000 (21:28 +1100)
committerJean-Marc Valin <jean-marc.valin@usherbrooke.ca>
Tue, 25 Mar 2008 10:28:40 +0000 (21:28 +1100)
remove unnecessary copying and some conditional branches.

libcelt/mathops.h
libcelt/vq.c

index d67db37..965deef 100644 (file)
@@ -54,6 +54,23 @@ static inline int find_max16(celt_word16_t *x, int len)
 }
 #endif
 
+#ifndef OVERRIDE_FIND_MAX32
+static inline int find_max32(celt_word32_t *x, int len)
+{
+   celt_word32_t max_corr=-VERY_LARGE16;
+   int i, id = 0;
+   for (i=0;i<len;i++)
+   {
+      if (x[i] > max_corr)
+      {
+         id = i;
+         max_corr = x[i];
+      }
+   }
+   return id;
+}
+#endif
+
 
 #ifndef FIXED_POINT
 
index f6389d2..7407151 100644 (file)
@@ -93,30 +93,19 @@ static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const
    RESTORE_STACK;
 }
 
-/** All the info necessary to keep track of a hypothesis during the search */
-struct NBest {
-   celt_word32_t score;
-   int sign;
-   int pos;
-   celt_word32_t xy;
-   celt_word32_t yy;
-   celt_word32_t yp;
-};
 
 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
 {
-   VARDECL(celt_norm_t, _y);
-   VARDECL(celt_norm_t, _ny);
-   VARDECL(int, _iy);
-   VARDECL(int, _iny);
+   VARDECL(celt_norm_t, y);
+   VARDECL(int, iy);
    VARDECL(int, signx);
-   celt_norm_t *y, *ny;
-   int *iy, *iny;
-   int i, j;
+   VARDECL(celt_word32_t, scores);
+   int i, j, is;
+   celt_word16_t s;
    int pulsesLeft;
+   celt_word32_t sum;
    celt_word32_t xy, yy, yp;
-   struct NBest nbest;
-   celt_word32_t Rpp=0, Rxp=0;
+   celt_word16_t Rpp;
 #ifdef FIXED_POINT
    int yshift;
 #endif
@@ -126,17 +115,11 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
    yshift = 14-EC_ILOG(K);
 #endif
 
-   ALLOC(_y, N, celt_norm_t);
-   ALLOC(_ny, N, celt_norm_t);
-   ALLOC(_iy, N, int);
-   ALLOC(_iny, N, int);
+   ALLOC(y, N, celt_norm_t);
+   ALLOC(iy, N, int);
    ALLOC(signx, N, int);
+   ALLOC(scores, N, celt_word32_t);
 
-   y = _y;
-   ny = _ny;
-   iy = _iy;
-   iny = _iny;
-   
    for (j=0;j<N;j++)
    {
       if (X[j]>0)
@@ -145,13 +128,12 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
          signx[j]=-1;
    }
    
+   sum = 0;
    for (j=0;j<N;j++)
    {
-      Rpp = MAC16_16(Rpp, P[j],P[j]);
-      Rxp = MAC16_16(Rxp, X[j],P[j]);
+      sum = MAC16_16(sum, P[j],P[j]);
    }
-   Rpp = ROUND16(Rpp, NORM_SHIFT);
-   Rxp = ROUND16(Rxp, NORM_SHIFT);
+   Rpp = ROUND16(sum, NORM_SHIFT);
 
    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
 
@@ -165,6 +147,9 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
    while (pulsesLeft > 0)
    {
       int pulsesAtOnce=1;
+      int sign;
+      celt_word32_t Rxy, Ryy, Ryp;
+      celt_word32_t g;
       
       /* Decide on how many pulses to find at once */
       pulsesAtOnce = pulsesLeft/N;
@@ -172,31 +157,31 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
          pulsesAtOnce = 1;
       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
 
-      nbest.score = -VERY_LARGE32;
-
-      for (j=0;j<N;j++)
+      /* Choose between fast and accurate strategy depending on where we are in the search */
+      if (pulsesLeft>1)
       {
-         int sign;
-         /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
-         celt_word32_t Rxy, Ryy, Ryp;
-         celt_word32_t score;
-         celt_word32_t g;
-         celt_word16_t s;
-         
-         /* Select sign based on X[j] alone */
-         sign = signx[j];
-         s = SHL16(sign*pulsesAtOnce, yshift);
-
-         /* Updating the sums of the new pulse(s) */
-         Rxy = xy + MULT16_16(s,X[j]);
-         Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
-         Ryp = yp + MULT16_16(s, P[j]);
-         
-         if (pulsesLeft>1)
+         for (j=0;j<N;j++)
          {
-            score = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
-         } else
+            /* Select sign based on X[j] alone */
+            sign = signx[j];
+            s = SHL16(sign*pulsesAtOnce, yshift);
+            /* Temporary sums of the new pulse(s) */
+            Rxy = xy + MULT16_16(s,X[j]);
+            Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
+            Ryp = yp + MULT16_16(s, P[j]);
+            scores[j] = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
+         }
+      } else {
+         for (j=0;j<N;j++)
          {
+            /* Select sign based on X[j] alone */
+            sign = signx[j];
+            s = SHL16(sign*pulsesAtOnce, yshift);
+            /* Temporary sums of the new pulse(s) */
+            Rxy = xy + MULT16_16(s,X[j]);
+            Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
+            Ryp = yp + MULT16_16(s, P[j]);
+
             /* Compute the gain such that ||p + g*y|| = 1 */
             g = MULT16_32_Q15(
                      celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
@@ -206,54 +191,23 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
             /* Knowing that gain, what's the error: (x-g*y)^2 
                (result is negated and we discard x^2 because it's constant) */
             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
-            score = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
+            scores[j] = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
                     - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
          }
-         
-         if (score>nbest.score)
-         {
-            nbest.score = score;
-            nbest.pos = j;
-            nbest.sign = sign;
-            nbest.xy = Rxy;
-            nbest.yy = Ryy;
-            nbest.yp = Ryp;
-         }
-      }
-
-      celt_assert2(nbest.score > -VERY_LARGE32, "Could not find any match in VQ codebook. Something got corrupted somewhere.");
-
-      /* Only now that we've made the final choice, update ny/iny and others */
-      {
-         int n;
-         int is;
-         celt_norm_t s;
-         is = nbest.sign*pulsesAtOnce;
-         s = SHL16(is, yshift);
-         for (n=0;n<N;n++)
-            ny[n] = y[n];
-         ny[nbest.pos] += s;
-
-         for (n=0;n<N;n++)
-            iny[n] = iy[n];
-         iny[nbest.pos] += is;
-
-         xy = nbest.xy;
-         yy = nbest.yy;
-         yp = nbest.yp;
-      }
-      /* Swap ny/iny with y/iy */
-      {
-         celt_norm_t *tmp_ny;
-         int *tmp_iny;
-
-         tmp_ny = ny;
-         ny = y;
-         y = tmp_ny;
-         tmp_iny = iny;
-         iny = iy;
-         iy = tmp_iny;
       }
+      
+      j = find_max32(scores, N);
+      is = signx[j]*pulsesAtOnce;
+      s = SHL16(is, yshift);
+
+      /* Updating the sums of the new pulse(s) */
+      xy = xy + MULT16_16(s,X[j]);
+      yy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
+      yp = yp + MULT16_16(s, P[j]);
+
+      /* Only now that we've made the final choice, update y/iy */
+      y[j] += s;
+      iy[j] += is;
       pulsesLeft -= pulsesAtOnce;
    }