vq search is now moving much less data around
authorJean-Marc Valin <Jean-Marc.Valin@csiro.au>
Thu, 14 Feb 2008 05:46:50 +0000 (16:46 +1100)
committerJean-Marc Valin <Jean-Marc.Valin@csiro.au>
Thu, 14 Feb 2008 05:46:50 +0000 (16:46 +1100)
libcelt/vq.c

index 01d2934..5ae7f54 100644 (file)
 #include "cwrs.h"
 #include "vq.h"
 
+struct NBest {
+   float score;
+   float gain;
+   int sign;
+   int pos;
+   int orig;
+   float xy;
+   float yy;
+   float yp;
+};
 
 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
    a combination of pulses such that its norm is still equal to 1. The only difference with 
@@ -51,15 +61,18 @@ void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *
    int *(iny[L]), *(iy[L]);
    int i, j, m;
    int pulsesLeft;
-   float xy[L], nxy[L];
-   float yy[L], nyy[L];
-   float yp[L], nyp[L];
-   float best_scores[L];
+   float xy[L];
+   float yy[L];
+   float yp[L];
+   struct NBest _nbest[L];
+   struct NBest *(nbest[L]);
    float Rpp=0, Rxp=0;
-   float gain[L];
    int maxL = 1;
    
    for (m=0;m<L;m++)
+      nbest[m] = &_nbest[m];
+   
+   for (m=0;m<L;m++)
    {
       ny[m] = _ny[m];
       iny[m] = _iny[m];
@@ -86,7 +99,7 @@ void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *
          iy[m][i] = iny[m][i] = 0;
 
    for (m=0;m<L;m++)
-      xy[m] = yy[m] = yp[m] = gain[m] = 0;
+      xy[m] = yy[m] = yp[m] = 0;
    
    pulsesLeft = K;
    while (pulsesLeft > 0)
@@ -108,7 +121,7 @@ void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *
       }
 
       for (m=0;m<L;m++)
-         best_scores[m] = -1e10;
+         nbest[m]->score = -1e10;
 
       for (m=0;m<L2;m++)
       {
@@ -136,48 +149,29 @@ void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *
                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
                score = 2*g*tmp_xy - g*g*tmp_yy;
 
-               if (score>best_scores[Lupdate-1])
+               if (score>nbest[Lupdate-1]->score)
                {
                   int k, n;
                   int id = Lupdate-1;
-                  float *tmp_ny;
-                  int *tmp_iny;
-                  
-                  tmp_ny = ny[Lupdate-1];
-                  tmp_iny = iny[Lupdate-1];
-                  while (id > 0 && score > best_scores[id-1])
+                  struct NBest *tmp_best;
+
+                  /* Save some pointers that would be deleted and use them for the current entry*/
+                  tmp_best = nbest[Lupdate-1];
+                  while (id > 0 && score > nbest[id-1]->score)
                      id--;
                
                   for (k=Lupdate-1;k>id;k--)
-                  {
-                     nxy[k] = nxy[k-1];
-                     nyy[k] = nyy[k-1];
-                     nyp[k] = nyp[k-1];
-                     //fprintf(stderr, "%d %d \n", N, k);
-                     ny[k] = ny[k-1];
-                     iny[k] = iny[k-1];
-                     gain[k] = gain[k-1];
-                     best_scores[k] = best_scores[k-1];
-                  }
-
-                  ny[id] = tmp_ny;
-                  iny[id] = tmp_iny;
-
-                  nxy[id] = tmp_xy;
-                  nyy[id] = tmp_yy;
-                  nyp[id] = tmp_yp;
-                  gain[id] = g;
-                  for (n=0;n<N;n++)
-                     ny[id][n] = y[m][n] - alpha*s*p[j]*p[n];
-                  ny[id][j] += s;
-
-                  for (n=0;n<N;n++)
-                     iny[id][n] = iy[m][n];
-                  if (s>0)
-                     iny[id][j] += pulsesAtOnce;
-                  else
-                     iny[id][j] -= pulsesAtOnce;
-                  best_scores[id] = score;
+                     nbest[k] = nbest[k-1];
+
+                  nbest[id] = tmp_best;
+                  nbest[id]->score = score;
+                  nbest[id]->pos = j;
+                  nbest[id]->orig = m;
+                  nbest[id]->sign = sign;
+                  nbest[id]->gain = g;
+                  nbest[id]->xy = tmp_xy;
+                  nbest[id]->yy = tmp_yy;
+                  nbest[id]->yp = tmp_yp;
                }
             }
          }
@@ -186,13 +180,29 @@ void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *
       int k;
       for (k=0;k<Lupdate;k++)
       {
+         int n;
+         int is;
+         float s;
+         is = nbest[k]->sign*pulsesAtOnce;
+         s = is;
+         for (n=0;n<N;n++)
+            ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
+         ny[k][nbest[k]->pos] += s;
+      
+         for (n=0;n<N;n++)
+            iny[k][n] = iy[nbest[k]->orig][n];
+         iny[k][nbest[k]->pos] += is;
+         
+         xy[k] = nbest[k]->xy;
+         yy[k] = nbest[k]->yy;
+         yp[k] = nbest[k]->yp;
+      }
+         
+      for (k=0;k<Lupdate;k++)
+      {
          float *tmp_ny;
          int *tmp_iny;
 
-         xy[k] = nxy[k];
-         yy[k] = nyy[k];
-         yp[k] = nyp[k];
-         
          tmp_ny = ny[k];
          ny[k] = y[k];
          y[k] = tmp_ny;
@@ -206,12 +216,12 @@ void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *
    if (0) {
       float err=0;
       for (i=0;i<N;i++)
-         err += (x[i]-gain[0]*y[0][i])*(x[i]-gain[0]*y[0][i]);
+         err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
       //if (N<=10)
       //printf ("%f %d %d\n", err, K, N);
    }
    for (i=0;i<N;i++)
-      x[i] = p[i]+gain[0]*y[0][i];
+      x[i] = p[i]+nbest[0]->gain*y[0][i];
    if (0) {
       float E=1e-15;
       int ABS = 0;