fixed-point: unquant_energy_mono() has received the fixed-point code from
[opus.git] / libcelt / cwrs.c
index 46e3f45..5d72707 100644 (file)
@@ -1,4 +1,5 @@
-/* (C) 2007 Timothy B. Terriberry */
+/* (C) 2007 Timothy B. Terriberry
+   (C) 2008 Jean-Marc Valin */
 /*
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
+
+/* Functions for encoding and decoding pulse vectors. For more details, see:
+   http://people.xiph.org/~tterribe/notes/cwrs.html
+*/
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
 #include <stdlib.h>
 #include "cwrs.h"
 
+/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
+   compute ncwrs() for m+1, for all n. Could also be used when m and n are
+   swapped just by changing nc */
+static void next_ncwrs32(celt_uint32_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint32_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint32_t tmp = nc[i]+nc[i-1]+mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
+}
+
+/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
+   compute ncwrs() for m-1, for all n. Could also be used when m and n are
+   swapped just by changing nc */
+static void prev_ncwrs32(celt_uint32_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint32_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint32_t tmp = nc[i]-nc[i-1]-mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
+}
+
+static void next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint64_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
+}
+
+static void prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint64_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
+}
+
 /*Returns the numer of ways of choosing _m elements from a set of size _n with
    replacement when a sign bit is needed for each unique element.*/
-#if 0
-static unsigned ncwrs(int _n,int _m){
-  static unsigned c[32][32];
-  if(_n<0||_m<0)return 0;
-  if(!c[_n][_m]){
-    if(_m<=0)c[_n][_m]=1;
-    else if(_n>0)c[_n][_m]=ncwrs(_n-1,_m)+ncwrs(_n,_m-1)+ncwrs(_n-1,_m-1);
-  }
-  return c[_n][_m];
+celt_uint32_t ncwrs(int _n,int _m)
+{
+   int i;
+   VARDECL(celt_uint32_t *nc);
+   ALLOC(nc,_n+1, celt_uint32_t);
+   for (i=0;i<_n+1;i++)
+      nc[i] = 1;
+   for (i=0;i<_m;i++)
+      next_ncwrs32(nc, _n+1, 0);
+   return nc[_n];
 }
-#else
-unsigned ncwrs(int _n,int _m){
-  unsigned ret;
-  unsigned f;
-  unsigned d;
-  int      i;
-  if(_n<0||_m<0)return 0;
-  if(_m==0)return 1;
-  if(_n==0)return 0;
-  ret=0;
-  f=_n;
-  d=1;
-  for(i=1;i<=_m;i++){
-    ret+=f*d<<i;
-    f=(f*(_n-i))/(i+1);
-    d=(d*(_m-i))/i;
-  }
-  return ret;
-}
-#endif
 
-celt_uint64_t ncwrs64(int _n,int _m){
-  celt_uint64_t ret;
-  celt_uint64_t f;
-  celt_uint64_t d;
-  int           i;
-  if(_n<0||_m<0)return 0;
-  if(_m==0)return 1;
-  if(_n==0)return 0;
-  ret=0;
-  f=_n;
-  d=1;
-  for(i=1;i<=_m;i++){
-    ret+=f*d<<i;
-    f=(f*(_n-i))/(i+1);
-    d=(d*(_m-i))/i;
-  }
-  return ret;
+/*Returns the numer of ways of choosing _m elements from a set of size _n with
+   replacement when a sign bit is needed for each unique element.*/
+celt_uint64_t ncwrs64(int _n,int _m)
+{
+   int i;
+   VARDECL(celt_uint64_t *nc);
+   ALLOC(nc,_n+1, celt_uint64_t);
+   for (i=0;i<_n+1;i++)
+      nc[i] = 1;
+   for (i=0;i<_m;i++)
+      next_ncwrs64(nc, _n+1, 0);
+   return nc[_n];
 }
 
+
 /*Returns the _i'th combination of _m elements chosen from a set of size _n
    with associated sign bits.
   _x:      Returns the combination with elements sorted in ascending order.
   _s:      Returns the associated sign bits.*/
-void cwrsi(int _n,int _m,unsigned _i,int *_x,int *_s){
+void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
   int j;
   int k;
+  VARDECL(celt_uint32_t *nc);
+  ALLOC(nc,_n+1, celt_uint32_t);
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m-1;k++)
+    next_ncwrs32(nc, _n+1, 0);
   for(k=j=0;k<_m;k++){
-    unsigned pn;
-    unsigned p;
-    unsigned t;
-    p=ncwrs(_n-j,_m-k-1);
-    pn=ncwrs(_n-j-1,_m-k-1);
+    celt_uint32_t pn, p, t;
+    /*p=ncwrs(_n-j,_m-k-1);
+    pn=ncwrs(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
     p+=pn;
     if(k>0){
       t=p>>1;
@@ -104,13 +163,18 @@ void cwrsi(int _n,int _m,unsigned _i,int *_x,int *_s){
       _i-=p;
       j++;
       p=pn;
-      pn=ncwrs(_n-j-1,_m-k-1);
+      /*pn=ncwrs(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
       p+=pn;
     }
     t=p>>1;
     _s[k]=_i>=t;
     _x[k]=j;
     if(_s[k])_i-=t;
+    if (k<_m-2)
+      prev_ncwrs32(nc, _n+1, 0);
+    else
+      prev_ncwrs32(nc, _n+1, 1);
   }
 }
 
@@ -118,23 +182,38 @@ void cwrsi(int _n,int _m,unsigned _i,int *_x,int *_s){
    of size _n with associated sign bits.
   _x:      The combination with elements sorted in ascending order.
   _s:      The associated sign bits.*/
-unsigned icwrs(int _n,int _m,const int *_x,const int *_s){
-  unsigned i;
+celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){
+  celt_uint32_t i;
   int      j;
   int      k;
+  VARDECL(celt_uint32_t *nc);
+  ALLOC(nc,_n+1, celt_uint32_t);
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m;k++)
+    next_ncwrs32(nc, _n+1, 0);
+  if (bound)
+    *bound = nc[_n];
   i=0;
   for(k=j=0;k<_m;k++){
-    unsigned pn;
-    unsigned p;
-    p=ncwrs(_n-j,_m-k-1);
-    pn=ncwrs(_n-j-1,_m-k-1);
+    celt_uint32_t pn;
+    celt_uint32_t p;
+    if (k<_m-1)
+      prev_ncwrs32(nc, _n+1, 0);
+    else
+      prev_ncwrs32(nc, _n+1, 1);
+    /*p=ncwrs(_n-j,_m-k-1);
+    pn=ncwrs(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
     p+=pn;
     if(k>0)p>>=1;
     while(j<_x[k]){
       i+=p;
       j++;
       p=pn;
-      pn=ncwrs(_n-j-1,_m-k-1);
+      /*pn=ncwrs(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
       p+=pn;
     }
     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
@@ -149,12 +228,18 @@ unsigned icwrs(int _n,int _m,const int *_x,const int *_s){
 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
   int j;
   int k;
+  VARDECL(celt_uint64_t *nc);
+  ALLOC(nc,_n+1, celt_uint64_t);
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m-1;k++)
+    next_ncwrs64(nc, _n+1, 0);
   for(k=j=0;k<_m;k++){
-    celt_uint64_t pn;
-    celt_uint64_t p;
-    celt_uint64_t t;
-    p=ncwrs64(_n-j,_m-k-1);
-    pn=ncwrs64(_n-j-1,_m-k-1);
+    celt_uint64_t pn, p, t;
+    /*p=ncwrs64(_n-j,_m-k-1);
+    pn=ncwrs64(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
     p+=pn;
     if(k>0){
       t=p>>1;
@@ -164,13 +249,18 @@ void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
       _i-=p;
       j++;
       p=pn;
-      pn=ncwrs64(_n-j-1,_m-k-1);
+      /*pn=ncwrs64(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
       p+=pn;
     }
     t=p>>1;
     _s[k]=_i>=t;
     _x[k]=j;
     if(_s[k])_i-=t;
+    if (k<_m-2)
+      prev_ncwrs64(nc, _n+1, 0);
+    else
+      prev_ncwrs64(nc, _n+1, 1);
   }
 }
 
@@ -178,23 +268,38 @@ void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
    of size _n with associated sign bits.
   _x:      The combination with elements sorted in ascending order.
   _s:      The associated sign bits.*/
-celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s){
+celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
   celt_uint64_t i;
   int           j;
   int           k;
+  VARDECL(celt_uint64_t *nc);
+  ALLOC(nc,_n+1, celt_uint64_t);
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m;k++)
+    next_ncwrs64(nc, _n+1, 0);
+  if (bound)
+     *bound = nc[_n];
   i=0;
   for(k=j=0;k<_m;k++){
     celt_uint64_t pn;
     celt_uint64_t p;
-    p=ncwrs64(_n-j,_m-k-1);
-    pn=ncwrs64(_n-j-1,_m-k-1);
+    if (k<_m-1)
+      prev_ncwrs64(nc, _n+1, 0);
+    else
+      prev_ncwrs64(nc, _n+1, 1);
+    /*p=ncwrs64(_n-j,_m-k-1);
+    pn=ncwrs64(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
     p+=pn;
     if(k>0)p>>=1;
     while(j<_x[k]){
       i+=p;
       j++;
       p=pn;
-      pn=ncwrs64(_n-j-1,_m-k-1);
+      /*pn=ncwrs64(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
       p+=pn;
     }
     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
@@ -241,95 +346,42 @@ void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
   }
 }
 
-#if 0
-#include <stdio.h>
-#define NMAX (10)
-#define MMAX (9)
-
-int main(int _argc,char **_argv){
-  int n;
-  for(n=0;n<=NMAX;n++){
-    int m;
-    for(m=0;m<=MMAX;m++){
-      unsigned nc;
-      unsigned i;
-      nc=ncwrs(n,m);
-      for(i=0;i<nc;i++){
-        int x[MMAX];
-        int s[MMAX];
-        int x2[MMAX];
-        int s2[MMAX];
-        int y[NMAX];
-        int j;
-        int k;
-        cwrsi(n,m,i,x,s);
-        printf("%6u of %u:",i,nc);
-        for(k=0;k<m;k++){
-          printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
-        }
-        printf(" ->");
-        if(icwrs(n,m,x,s)!=i){
-          fprintf(stderr,"Combination-index mismatch.\n");
-        }
-        comb2pulse(n,m,y,x,s);
-        for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
-        printf("\n");
-        pulse2comb(n,m,x2,s2,y);
-        for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
-          fprintf(stderr,"Pulse-combination mismatch.\n");
-          break;
-        }
-      }
-      printf("\n");
-    }
-  }
-  return -1;
+void encode_pulses(int *_y, int N, int K, ec_enc *enc)
+{
+   VARDECL(int *comb);
+   VARDECL(int *signs);
+   
+   ALLOC(comb, K, int);
+   ALLOC(signs, K, int);
+   
+   pulse2comb(N, K, comb, signs, _y);
+   /* Go with 32-bit path if we're sure we can */
+   if (N<=13 && K<=13)
+   {
+      celt_uint32_t bound, id;
+      id = icwrs(N, K, comb, signs, &bound);
+      ec_enc_uint(enc,id,bound);
+   } else {
+      celt_uint64_t bound, id;
+      id = icwrs64(N, K, comb, signs, &bound);
+      ec_enc_uint64(enc,id,bound);
+   }
 }
-#endif
-
-#if 0
-#include <stdio.h>
-#define NMAX (32)
-#define MMAX (16)
 
-int main(int _argc,char **_argv){
-  int n;
-  for(n=0;n<=NMAX;n+=3){
-    int m;
-    for(m=0;m<=MMAX;m++){
-      celt_uint64_t nc;
-      celt_uint64_t i;
-      nc=ncwrs64(n,m);
-      printf("%d/%d: %llu",n,m, nc);
-      for(i=0;i<nc;i+=100000){
-        int x[MMAX];
-        int s[MMAX];
-        int x2[MMAX];
-        int s2[MMAX];
-        int y[NMAX];
-        int j;
-        int k;
-        cwrsi64(n,m,i,x,s);
-        /*printf("%llu of %llu:",i,nc);
-        for(k=0;k<m;k++){
-          printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
-        }
-        printf(" ->");*/
-        if(icwrs64(n,m,x,s)!=i){
-          fprintf(stderr,"Combination-index mismatch.\n");
-        }
-        comb2pulse(n,m,y,x,s);
-        /*for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
-        printf("\n");*/
-        pulse2comb(n,m,x2,s2,y);
-        for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
-          fprintf(stderr,"Pulse-combination mismatch.\n");
-          break;
-        }
-      }
-      printf("\n");
-    }
-  }
-  return 0;
+void decode_pulses(int *_y, int N, int K, ec_dec *dec)
+{
+   VARDECL(int *comb);
+   VARDECL(int *signs);
+   
+   ALLOC(comb, K, int);
+   ALLOC(signs, K, int);
+   if (N<=13 && K<=13)
+   {
+      cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs);
+      comb2pulse(N, K, _y, comb, signs);
+   } else {
+      cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
+      comb2pulse(N, K, _y, comb, signs);
+   }
 }
-#endif
+