Working allocation interpolation code
authorJean-Marc Valin <jean-marc.valin@usherbrooke.ca>
Wed, 16 Jan 2008 11:04:17 +0000 (22:04 +1100)
committerJean-Marc Valin <jean-marc.valin@usherbrooke.ca>
Wed, 16 Jan 2008 11:04:17 +0000 (22:04 +1100)
libcelt/rate.c

index 6d98dd8..f22ab44 100644 (file)
 #include "entcode.h"
 
 #define BITRES 4
+#define BITROUND 8
 #define BITOVERFLOW 1000
 
+#define MAX_PULSES 64
+
 int log2_frac(ec_uint32 val, int frac)
 {
    int i;
@@ -88,7 +91,7 @@ int log2_frac64(ec_uint64 val, int frac)
    return L;
 }
 
-int bits2pulses(int bits, int N)
+int bits2pulses0(int bits, int N)
 {
    int i, b, prev;
    /* FIXME: This is terribly inefficient. Do a bisection instead
@@ -110,6 +113,7 @@ int bits2pulses(int bits, int N)
 
 struct alloc_data {
    int len;
+   const int *bands;
    int **bits;
    int **rev_bits;
 };
@@ -118,9 +122,12 @@ void alloc_init(struct alloc_data *alloc, const CELTMode *m)
 {
    int i, prevN, BC;
    const int *eBands = m->eBands;
+   
+   alloc->len = m->nbEBands;
+   alloc->bands = m->eBands;
    alloc->bits = celt_alloc(m->nbEBands*sizeof(int*));
    alloc->rev_bits = celt_alloc(m->nbEBands*sizeof(int*));
-   alloc->len = m->nbEBands;
+   
    BC = m->nbMdctBlocks*m->nbChannels;
    prevN = -1;
    for (i=0;i<alloc->len;i++)
@@ -133,19 +140,19 @@ void alloc_init(struct alloc_data *alloc, const CELTMode *m)
       } else {
          int j;
          /* FIXME: We could save memory here */
-         alloc->bits[i] = celt_alloc(64*sizeof(int));
-         alloc->rev_bits[i] = celt_alloc(64*sizeof(int));
-         for (j=0;j<64;j++)
+         alloc->bits[i] = celt_alloc(MAX_PULSES*sizeof(int));
+         alloc->rev_bits[i] = celt_alloc(MAX_PULSES*sizeof(int));
+         for (j=0;j<MAX_PULSES;j++)
          {
             alloc->bits[i][j] = log2_frac64(ncwrs64(N, j),BITRES);
             /* We could just update rev_bits here */
-            if (alloc->bits[i][j] > (60>>BITRES))
+            if (alloc->bits[i][j] > (60<<BITRES))
                break;
          }
-         for (;j<64;j++)
+         for (;j<MAX_PULSES;j++)
             alloc->bits[i][j] = BITOVERFLOW;
          for (j=0;j<32;j++)
-            alloc->rev_bits[i][j] = bits2pulses(j, N);
+            alloc->rev_bits[i][j] = bits2pulses0(j, N);
          prevN = N;
       }
    }
@@ -191,31 +198,52 @@ int compute_allocation(const CELTMode *m, int *pulses)
    return (bits+255)>>8;
 }
 
+int bits2pulses(const struct alloc_data *alloc, int band, int bits)
+{
+   int lo, hi;
+   lo = 0;
+   hi = MAX_PULSES;
+   
+   while (hi-lo != 1)
+   {
+      int mid = (lo+hi)>>1;
+      if (alloc->bits[band][mid] >= bits)
+         hi = mid;
+      else
+         lo = mid;
+   }
+   if (bits-alloc->bits[band][lo] <= alloc->bits[band][hi]-bits)
+      return lo;
+   else
+      return hi;
+}
 
-int vec_bits2pulses(int *bands, int *bits, int *pulses, int len, int B)
+int vec_bits2pulses(const struct alloc_data *alloc, const int *bands, int *bits, int *pulses, int len, int B)
 {
    int i;
    int sum=0;
    for (i=0;i<len;i++)
    {
       int N = (bands[i+1]-bands[i])*B;
-      pulses[i] = bits2pulses(bits[i], N);
-      sum += log2_frac64(ncwrs(N, pulses[i]),8);
+      pulses[i] = bits2pulses(alloc, i, bits[i]);
+      sum += alloc->bits[i][pulses[i]];
    }
-   return (sum+255)>>8;
+   return sum;
 }
 
-int interp_bits2pulses(int *bands, int *bits1, int *bits2, int total, int *pulses, int len, int B)
+#if 0
+int interp_bits2pulses(const struct alloc_data *alloc, int *bits1, int *bits2, int total, int *pulses, int len, int B)
 {
    int i;
+   const int *bands = alloc->bands;
    /* FIXME: This too is terribly inefficient. We should do a bisection instead */
    for (i=0;i<16;i++)
    {
       int j;
       int bits[len];
       for (j=0;j<len;j++)
-         bits[j] = ((16-i)*bits1[j] + i*bits2[j]) >> 4;
-      if (vec_bits2pulses(bands, bits, pulses, len, B) > total)
+         bits[j] = ((16-i)*bits1[j] + i*bits2[j]);
+      if (vec_bits2pulses(alloc, bands, bits, pulses, len, B) > total)
          break;
    }
    if (i==0)
@@ -226,10 +254,56 @@ int interp_bits2pulses(int *bands, int *bits1, int *bits2, int total, int *pulse
       /* Get the previous one (that didn't bust). Should rewrite that anyway */
       i--;
       for (j=0;j<len;j++)
-         bits[j] = ((16-i)*bits1[j] + i*bits2[j]) >> 4;      
-      return vec_bits2pulses(bands, bits, pulses, len, B);
+         bits[j] = ((16-i)*bits1[j] + i*bits2[j]);
+      return vec_bits2pulses(alloc, bands, bits, pulses, len, B);
+   }
+}
+#else
+int interp_bits2pulses(const struct alloc_data *alloc, int *bits1, int *bits2, int total, int *pulses, int len, int B)
+{
+   int lo, hi, out;
+   int j;
+   int bits[len];
+   int used_bits[len];
+   const int *bands = alloc->bands;
+   lo = 0;
+   hi = 1<<BITRES;
+   while (hi-lo != 1)
+   {
+      int mid = (lo+hi)>>1;
+      for (j=0;j<len;j++)
+         bits[j] = ((1<<BITRES)-mid)*bits1[j] + mid*bits2[j];
+      if (vec_bits2pulses(alloc, bands, bits, pulses, len, B) > total<<BITRES)
+         hi = mid;
+      else
+         lo = mid;
    }
+   
+   for (j=0;j<len;j++)
+      bits[j] = ((1<<BITRES)-lo)*bits1[j] + lo*bits2[j];
+   out = vec_bits2pulses(alloc, bands, bits, pulses, len, B);
+   /* Do some refinement to use up all bits */
+   while(1)
+   {
+      int incremented = 0;
+      for (j=0;j<len;j++)
+      {
+         if (alloc->bits[j][pulses[j]] < bits[j])
+         {
+            if (out+alloc->bits[j][pulses[j]+1]-alloc->bits[j][pulses[j]] <= total<<BITRES)
+            {
+               out = out+alloc->bits[j][pulses[j]+1]-alloc->bits[j][pulses[j]];
+               pulses[j] += 1;
+               incremented = 1;
+            }
+         }
+      }
+      if (!incremented)
+         break;
+   }
+   return (out+BITROUND) >> BITRES;
 }
+#endif
 
 #if 0
 int main()
@@ -262,12 +336,12 @@ int main()
    struct alloc_data alloc;
    
    alloc_init(&alloc, celt_mode0);
-   int b = vec_bits2pulses(bank, bits, pulses, 18, 1);
+   int b = vec_bits2pulses(&alloc, bank, bits, pulses, 18, 1);
    printf ("total: %d bits\n", b);
    for (i=0;i<18;i++)
       printf ("%d ", pulses[i]);
    printf ("\n");
-   b = interp_bits2pulses(bank, bits1, bits2, 160, pulses, 18, 1);
+   b = interp_bits2pulses(&alloc, bits1, bits2, 162, pulses, 18, 1);
    printf ("total: %d bits\n", b);
    for (i=0;i<18;i++)
       printf ("%d ", pulses[i]);