Removed the "pitch compression" in the residual quantisation. Also, removed
authorJean-Marc Valin <Jean-Marc.Valin@csiro.au>
Tue, 25 Mar 2008 03:15:41 +0000 (14:15 +1100)
committerJean-Marc Valin <Jean-Marc.Valin@csiro.au>
Tue, 25 Mar 2008 03:15:41 +0000 (14:15 +1100)
the more complex "n-best search" and replaced it with a greedy search

libcelt/bands.c
libcelt/vq.c
libcelt/vq.h

index 6b44bb3..c48342e 100644 (file)
@@ -291,7 +291,6 @@ void quant_bands(const CELTMode *m, celt_norm_t *X, celt_norm_t *P, celt_mask_t
 {
    int i, j, B, bits;
    const celt_int16_t *eBands = m->eBands;
-   celt_word16_t alpha;
    VARDECL(celt_norm_t, norm);
    VARDECL(int, pulses);
    VARDECL(int, offsets);
@@ -325,13 +324,10 @@ void quant_bands(const CELTMode *m, celt_norm_t *X, celt_norm_t *P, celt_mask_t
       if (eBands[i] >= m->pitchEnd || q<=0)
       {
          q -= 1;
-         alpha = 0;
          if (q<0)
             intra_fold(X+B*eBands[i], B*(eBands[i+1]-eBands[i]), norm, P+B*eBands[i], B, eBands[i], eBands[m->nbEBands+1]);
          else
             intra_prediction(X+B*eBands[i], W+B*eBands[i], B*(eBands[i+1]-eBands[i]), q, norm, P+B*eBands[i], B, eBands[i], enc);
-      } else {
-         alpha = QCONST16(.7f,15);
       }
       
       if (q > 0)
@@ -339,7 +335,7 @@ void quant_bands(const CELTMode *m, celt_norm_t *X, celt_norm_t *P, celt_mask_t
          int nb_rotations = (B*(eBands[i+1]-eBands[i])+4*q)/(8*q);
          exp_rotation(P+B*eBands[i], B*(eBands[i+1]-eBands[i]), -1, B, nb_rotations);
          exp_rotation(X+B*eBands[i], B*(eBands[i+1]-eBands[i]), -1, B, nb_rotations);
-         alg_quant(X+B*eBands[i], W+B*eBands[i], B*(eBands[i+1]-eBands[i]), q, P+B*eBands[i], alpha, enc);
+         alg_quant(X+B*eBands[i], W+B*eBands[i], B*(eBands[i+1]-eBands[i]), q, P+B*eBands[i], enc);
          exp_rotation(X+B*eBands[i], B*(eBands[i+1]-eBands[i]), 1, B, nb_rotations);
       }
       for (j=B*eBands[i];j<B*eBands[i+1];j++)
@@ -355,7 +351,6 @@ void unquant_bands(const CELTMode *m, celt_norm_t *X, celt_norm_t *P, int total_
 {
    int i, j, B, bits;
    const celt_int16_t *eBands = m->eBands;
-   celt_word16_t alpha;
    VARDECL(celt_norm_t, norm);
    VARDECL(int, pulses);
    VARDECL(int, offsets);
@@ -384,20 +379,17 @@ void unquant_bands(const CELTMode *m, celt_norm_t *X, celt_norm_t *P, int total_
       if (eBands[i] >= m->pitchEnd || q<=0)
       {
          q -= 1;
-         alpha = 0;
          if (q<0)
             intra_fold(X+B*eBands[i], B*(eBands[i+1]-eBands[i]), norm, P+B*eBands[i], B, eBands[i], eBands[m->nbEBands+1]);
          else
             intra_unquant(X+B*eBands[i], B*(eBands[i+1]-eBands[i]), q, norm, P+B*eBands[i], B, eBands[i], dec);
-      } else {
-         alpha = QCONST16(.7f,15);
       }
       
       if (q > 0)
       {
          int nb_rotations = (B*(eBands[i+1]-eBands[i])+4*q)/(8*q);
          exp_rotation(P+B*eBands[i], B*(eBands[i+1]-eBands[i]), -1, B, nb_rotations);
-         alg_unquant(X+B*eBands[i], B*(eBands[i+1]-eBands[i]), q, P+B*eBands[i], alpha, dec);
+         alg_unquant(X+B*eBands[i], B*(eBands[i+1]-eBands[i]), q, P+B*eBands[i], dec);
          exp_rotation(X+B*eBands[i], B*(eBands[i+1]-eBands[i]), 1, B, nb_rotations);
       }
       for (j=B*eBands[i];j<B*eBands[i+1];j++)
index 6aaee71..6e32a34 100644 (file)
@@ -42,7 +42,7 @@
 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
    applies the compression in the pitch direction, computes the gain that will
    give ||p+g*y||=1 and mixes the residual with the pitch. */
-static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
+static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P)
 {
    int i;
    celt_word32_t Ryp, Ryy, Rpp;
@@ -68,10 +68,9 @@ static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const
       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
 
    /* Remove part of the pitch component to compute the real residual from
-      the encoded (int) one */
+   the encoded (int) one */
    for (i=0;i<N;i++)
-      y[i] = SUB16(SHL16(iy[i],yshift),
-                   MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND16(Ryp,14),P[i])));
+      y[i] = SHL16(iy[i],yshift);
 
    /* Recompute after the projection (I think it's right) */
    Ryp = 0;
@@ -105,26 +104,19 @@ struct NBest {
    celt_word32_t yp;
 };
 
-void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
+void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
 {
-   int L = 3;
    VARDECL(celt_norm_t, _y);
    VARDECL(celt_norm_t, _ny);
    VARDECL(int, _iy);
    VARDECL(int, _iny);
-   VARDECL(celt_norm_t *, y);
-   VARDECL(celt_norm_t *, ny);
-   VARDECL(int *, iy);
-   VARDECL(int *, iny);
-   int i, j, k, m;
+   celt_norm_t *y, *ny;
+   int *iy, *iny;
+   int i, j;
    int pulsesLeft;
-   VARDECL(celt_word32_t, xy);
-   VARDECL(celt_word32_t, yy);
-   VARDECL(celt_word32_t, yp);
-   VARDECL(struct NBest, _nbest);
-   VARDECL(struct NBest *, nbest);
+   celt_word32_t xy, yy, yp;
+   struct NBest nbest;
    celt_word32_t Rpp=0, Rxp=0;
-   int maxL = 1;
 #ifdef FIXED_POINT
    int yshift;
 #endif
@@ -134,31 +126,14 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
    yshift = 14-EC_ILOG(K);
 #endif
 
-   ALLOC(_y, L*N, celt_norm_t);
-   ALLOC(_ny, L*N, celt_norm_t);
-   ALLOC(_iy, L*N, int);
-   ALLOC(_iny, L*N, int);
-   ALLOC(y, L, celt_norm_t*);
-   ALLOC(ny, L, celt_norm_t*);
-   ALLOC(iy, L, int*);
-   ALLOC(iny, L, int*);
-   
-   ALLOC(xy, L, celt_word32_t);
-   ALLOC(yy, L, celt_word32_t);
-   ALLOC(yp, L, celt_word32_t);
-   ALLOC(_nbest, L, struct NBest);
-   ALLOC(nbest, L, struct NBest *);
-   
-   for (m=0;m<L;m++)
-      nbest[m] = &_nbest[m];
-   
-   for (m=0;m<L;m++)
-   {
-      ny[m] = &_ny[m*N];
-      iny[m] = &_iny[m*N];
-      y[m] = &_y[m*N];
-      iy[m] = &_iy[m*N];
-   }
+   ALLOC(_y, N, celt_norm_t);
+   ALLOC(_ny, N, celt_norm_t);
+   ALLOC(_iy, N, int);
+   ALLOC(_iny, N, int);
+   y = _y;
+   ny = _ny;
+   iy = _iy;
+   iny = _iny;
    
    for (j=0;j<N;j++)
    {
@@ -167,186 +142,130 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
    }
    Rpp = ROUND16(Rpp, NORM_SHIFT);
    Rxp = ROUND16(Rxp, NORM_SHIFT);
+
    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
 
-   /* We only need to initialise the zero because the first iteration only uses that */
    for (i=0;i<N;i++)
-      y[0][i] = 0;
+      y[i] = 0;
    for (i=0;i<N;i++)
-      iy[0][i] = 0;
-   xy[0] = yy[0] = yp[0] = 0;
+      iy[i] = 0;
+   xy = yy = yp = 0;
 
    pulsesLeft = K;
    while (pulsesLeft > 0)
    {
       int pulsesAtOnce=1;
-      int Lupdate = L;
-      int L2 = L;
       
-      /* Decide on complexity strategy */
+      /* Decide on how many pulses to find at once */
       pulsesAtOnce = pulsesLeft/N;
       if (pulsesAtOnce<1)
          pulsesAtOnce = 1;
-      if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
-         Lupdate = 1;
       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
-      L2 = Lupdate;
-      if (L2>maxL)
-      {
-         L2 = maxL;
-         maxL *= N;
-      }
 
-      for (m=0;m<Lupdate;m++)
-         nbest[m]->score = -VERY_LARGE32;
+      nbest.score = -VERY_LARGE32;
 
-      for (m=0;m<L2;m++)
+      for (j=0;j<N;j++)
       {
-         for (j=0;j<N;j++)
+         int sign;
+         /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
+         celt_word32_t Rxy, Ryy, Ryp;
+         celt_word32_t score;
+         celt_word32_t g;
+         celt_word16_t s;
+         
+         /* Select sign based on X[j] alone */
+         if (X[j]>0) sign=1; else sign=-1;
+         s = SHL16(sign*pulsesAtOnce, yshift);
+
+         /* Updating the sums of the new pulse(s) */
+         Rxy = xy + MULT16_16(s,X[j]);
+         Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
+         Ryp = yp + MULT16_16(s, P[j]);
+         
+         if (pulsesLeft>1)
+         {
+            score = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
+         } else
          {
-            int sign;
-            /*if (x[j]>0) sign=1; else sign=-1;*/
-            for (sign=-1;sign<=1;sign+=2)
-            {
-               /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
-               celt_word32_t Rxy, Ryy, Ryp;
-               celt_word16_t spj, aspj; /* Intermediate results */
-               celt_word32_t score;
-               celt_word32_t g;
-               celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
-               
-               /* All pulses at one location must have the same sign. */
-               if (iy[m][j]*sign < 0)
-                  continue;
-
-               spj = MULT16_16_Q14(s, P[j]);
-               aspj = MULT16_16_Q15(alpha, spj);
-               /* Updating the sums of the new pulse(s) */
-               Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_Q15(alpha,spj),Rxp);
-               Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
-               Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
-               
-               /* Compute the gain such that ||p + g*y|| = 1 */
-               g = MULT16_32_Q15(
-                        celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
-                                  MULT16_16(ROUND16(Ryy,14),Rpp))
-                        - ROUND16(Ryp,14),
-                   celt_rcp(SHR32(Ryy,12)));
-               /* Knowing that gain, what's the error: (x-g*y)^2 
-                  (result is negated and we discard x^2 because it's constant) */
-               /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
-               score = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
-                       - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
-
-               if (score>nbest[Lupdate-1]->score)
-               {
-                  int id = Lupdate-1;
-                  struct NBest *tmp_best;
-
-                  /* Save some pointers that would be deleted and use them for the current entry*/
-                  tmp_best = nbest[Lupdate-1];
-                  while (id > 0 && score > nbest[id-1]->score)
-                     id--;
-               
-                  for (k=Lupdate-1;k>id;k--)
-                     nbest[k] = nbest[k-1];
-
-                  nbest[id] = tmp_best;
-                  nbest[id]->score = score;
-                  nbest[id]->pos = j;
-                  nbest[id]->orig = m;
-                  nbest[id]->sign = sign;
-                  nbest[id]->xy = Rxy;
-                  nbest[id]->yy = Ryy;
-                  nbest[id]->yp = Ryp;
-               }
-            }
+            /* Compute the gain such that ||p + g*y|| = 1 */
+            g = MULT16_32_Q15(
+                     celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
+                               MULT16_16(ROUND16(Ryy,14),Rpp))
+                     - ROUND16(Ryp,14),
+                celt_rcp(SHR32(Ryy,12)));
+            /* Knowing that gain, what's the error: (x-g*y)^2 
+               (result is negated and we discard x^2 because it's constant) */
+            /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
+            score = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
+                    - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
+         }
+         
+         if (score>nbest.score)
+         {
+            nbest.score = score;
+            nbest.pos = j;
+            nbest.orig = 0;
+            nbest.sign = sign;
+            nbest.xy = Rxy;
+            nbest.yy = Ryy;
+            nbest.yp = Ryp;
          }
-
       }
-      
+
       celt_assert2(nbest[0]->score > -VERY_LARGE32, "Could not find any match in VQ codebook. Something got corrupted somewhere.");
+
       /* Only now that we've made the final choice, update ny/iny and others */
-      for (k=0;k<Lupdate;k++)
       {
          int n;
          int is;
          celt_norm_t s;
-         is = nbest[k]->sign*pulsesAtOnce;
+         is = nbest.sign*pulsesAtOnce;
          s = SHL16(is, yshift);
          for (n=0;n<N;n++)
-            ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
-         ny[k][nbest[k]->pos] += s;
+            ny[n] = y[n];
+         ny[nbest.pos] += s;
 
          for (n=0;n<N;n++)
-            iny[k][n] = iy[nbest[k]->orig][n];
-         iny[k][nbest[k]->pos] += is;
+            iny[n] = iy[n];
+         iny[nbest.pos] += is;
 
-         xy[k] = nbest[k]->xy;
-         yy[k] = nbest[k]->yy;
-         yp[k] = nbest[k]->yp;
+         xy = nbest.xy;
+         yy = nbest.yy;
+         yp = nbest.yp;
       }
       /* Swap ny/iny with y/iy */
-      for (k=0;k<Lupdate;k++)
       {
          celt_norm_t *tmp_ny;
          int *tmp_iny;
 
-         tmp_ny = ny[k];
-         ny[k] = y[k];
-         y[k] = tmp_ny;
-         tmp_iny = iny[k];
-         iny[k] = iy[k];
-         iy[k] = tmp_iny;
+         tmp_ny = ny;
+         ny = y;
+         y = tmp_ny;
+         tmp_iny = iny;
+         iny = iy;
+         iy = tmp_iny;
       }
       pulsesLeft -= pulsesAtOnce;
    }
    
-#if 0
-   if (0) {
-      celt_word32_t err=0;
-      for (i=0;i<N;i++)
-         err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
-      /*if (N<=10)
-        printf ("%f %d %d\n", err, K, N);*/
-   }
-   /* Sanity checks, don't bother */
-   if (0) {
-      for (i=0;i<N;i++)
-         x[i] = p[i]+nbest[0]->gain*y[0][i];
-      celt_word32_t E=1e-15;
-      int ABS = 0;
-      for (i=0;i<N;i++)
-         ABS += abs(iy[0][i]);
-      /*if (K != ABS)
-         printf ("%d %d\n", K, ABS);*/
-      for (i=0;i<N;i++)
-         E += x[i]*x[i];
-      /*printf ("%f\n", E);*/
-      E = 1/sqrt(E);
-      for (i=0;i<N;i++)
-         x[i] *= E;
-   }
-#endif
-   
-   encode_pulses(iy[0], N, K, enc);
+   encode_pulses(iy, N, K, enc);
    
    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
-      due to the recursive computation used in quantisation.
-      Not quite sure whether we need that or not */
-   mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
+   due to the recursive computation used in quantisation. */
+   mix_pitch_and_residual(iy, X, N, K, P);
    RESTORE_STACK;
 }
 
+
 /** Decode pulse vector and combine the result with the pitch vector to produce
     the final normalised signal in the current band. */
-void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
+void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
 {
    VARDECL(int, iy);
    SAVE_STACK;
    ALLOC(iy, N, int);
    decode_pulses(iy, N, K, dec);
-   mix_pitch_and_residual(iy, X, N, K, P, alpha);
+   mix_pitch_and_residual(iy, X, N, K, P);
    RESTORE_STACK;
 }
 
index 2f19b37..f5c507d 100644 (file)
@@ -51,7 +51,7 @@
  * @param alpha compression factor to apply in the pitch direction (magic!)
  * @param enc Entropy encoder state
 */
-void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc);
+void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc);
 
 /** Algebraic pulse decoder
  * @param x Decoded normalised spectrum (returned)
@@ -61,7 +61,7 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
  * @param alpha compression factor in the pitch direction (magic!)
  * @param dec Entropy decoder state
  */
-void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec);
+void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec);
 
 /** Intra-frame predictor that matches a section of the current frame (at lower
  * frequencies) to encode the current band.