Reduced useless calls to ncwrs64() by half.
[opus.git] / libcelt / cwrs.c
index 46e3f45..7726740 100644 (file)
 #include <stdlib.h>
 #include "cwrs.h"
 
+static celt_uint64_t next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint64_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
+}
+
+static celt_uint64_t prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint64_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
+}
+
+/* Optional implementation of ncwrs64 using update_ncwrs64(). It's slightly
+   slower than the standard ncwrs64(), but it could still be useful.
+celt_uint64_t ncwrs64_opt(int _n,int _m)
+{
+   int i;
+   celt_uint64_t ret;
+   celt_uint64_t nc[_n+1];
+   for (i=0;i<_n+1;i++)
+      nc[i] = 1;
+   for (i=0;i<_m;i++)
+      update_ncwrs64(nc, _n+1, 0);
+   return nc[_n];
+}*/
+
 /*Returns the numer of ways of choosing _m elements from a set of size _n with
    replacement when a sign bit is needed for each unique element.*/
 #if 0
-static unsigned ncwrs(int _n,int _m){
-  static unsigned c[32][32];
+static celt_uint32_t ncwrs(int _n,int _m){
+  static celt_uint32_t c[32][32];
   if(_n<0||_m<0)return 0;
   if(!c[_n][_m]){
     if(_m<=0)c[_n][_m]=1;
@@ -43,10 +87,10 @@ static unsigned ncwrs(int _n,int _m){
   return c[_n][_m];
 }
 #else
-unsigned ncwrs(int _n,int _m){
-  unsigned ret;
-  unsigned f;
-  unsigned d;
+celt_uint32_t ncwrs(int _n,int _m){
+  celt_uint32_t ret;
+  celt_uint32_t f;
+  celt_uint32_t d;
   int      i;
   if(_n<0||_m<0)return 0;
   if(_m==0)return 1;
@@ -63,6 +107,17 @@ unsigned ncwrs(int _n,int _m){
 }
 #endif
 
+#if 0
+celt_uint64_t ncwrs64(int _n,int _m){
+  static celt_uint64_t c[101][101];
+  if(_n<0||_m<0)return 0;
+  if(!c[_n][_m]){
+    if(_m<=0)c[_n][_m]=1;
+    else if(_n>0)c[_n][_m]=ncwrs64(_n-1,_m)+ncwrs64(_n,_m-1)+ncwrs64(_n-1,_m-1);
+}
+  return c[_n][_m];
+}
+#else
 celt_uint64_t ncwrs64(int _n,int _m){
   celt_uint64_t ret;
   celt_uint64_t f;
@@ -81,18 +136,19 @@ celt_uint64_t ncwrs64(int _n,int _m){
   }
   return ret;
 }
+#endif
 
 /*Returns the _i'th combination of _m elements chosen from a set of size _n
    with associated sign bits.
   _x:      Returns the combination with elements sorted in ascending order.
   _s:      Returns the associated sign bits.*/
-void cwrsi(int _n,int _m,unsigned _i,int *_x,int *_s){
+void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
   int j;
   int k;
   for(k=j=0;k<_m;k++){
-    unsigned pn;
-    unsigned p;
-    unsigned t;
+    celt_uint32_t pn;
+    celt_uint32_t p;
+    celt_uint32_t t;
     p=ncwrs(_n-j,_m-k-1);
     pn=ncwrs(_n-j-1,_m-k-1);
     p+=pn;
@@ -118,14 +174,14 @@ void cwrsi(int _n,int _m,unsigned _i,int *_x,int *_s){
    of size _n with associated sign bits.
   _x:      The combination with elements sorted in ascending order.
   _s:      The associated sign bits.*/
-unsigned icwrs(int _n,int _m,const int *_x,const int *_s){
-  unsigned i;
+celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s){
+  celt_uint32_t i;
   int      j;
   int      k;
   i=0;
   for(k=j=0;k<_m;k++){
-    unsigned pn;
-    unsigned p;
+    celt_uint32_t pn;
+    celt_uint32_t p;
     p=ncwrs(_n-j,_m-k-1);
     pn=ncwrs(_n-j-1,_m-k-1);
     p+=pn;
@@ -149,12 +205,19 @@ unsigned icwrs(int _n,int _m,const int *_x,const int *_s){
 void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
   int j;
   int k;
+  celt_uint64_t nc[_n+1];
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m-1;k++)
+    next_ncwrs64(nc, _n+1, 0);
   for(k=j=0;k<_m;k++){
     celt_uint64_t pn;
     celt_uint64_t p;
     celt_uint64_t t;
-    p=ncwrs64(_n-j,_m-k-1);
-    pn=ncwrs64(_n-j-1,_m-k-1);
+    /*p=ncwrs64(_n-j,_m-k-1);
+    pn=ncwrs64(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
     p+=pn;
     if(k>0){
       t=p>>1;
@@ -164,13 +227,18 @@ void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
       _i-=p;
       j++;
       p=pn;
-      pn=ncwrs64(_n-j-1,_m-k-1);
+      /*pn=ncwrs64(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
       p+=pn;
     }
     t=p>>1;
     _s[k]=_i>=t;
     _x[k]=j;
     if(_s[k])_i-=t;
+    if (k<_m-2)
+      prev_ncwrs64(nc, _n+1, 0);
+    else
+      prev_ncwrs64(nc, _n+1, 1);
   }
 }
 
@@ -178,23 +246,37 @@ void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
    of size _n with associated sign bits.
   _x:      The combination with elements sorted in ascending order.
   _s:      The associated sign bits.*/
-celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s){
+celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
   celt_uint64_t i;
   int           j;
   int           k;
+  celt_uint64_t nc[_n+1];
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m;k++)
+    next_ncwrs64(nc, _n+1, 0);
+  if (bound)
+     *bound = nc[_n];
   i=0;
   for(k=j=0;k<_m;k++){
     celt_uint64_t pn;
     celt_uint64_t p;
-    p=ncwrs64(_n-j,_m-k-1);
-    pn=ncwrs64(_n-j-1,_m-k-1);
+    if (k<_m-1)
+      prev_ncwrs64(nc, _n+1, 0);
+    else
+      prev_ncwrs64(nc, _n+1, 1);
+    /*p=ncwrs64(_n-j,_m-k-1);
+    pn=ncwrs64(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
     p+=pn;
     if(k>0)p>>=1;
     while(j<_x[k]){
       i+=p;
       j++;
       p=pn;
-      pn=ncwrs64(_n-j-1,_m-k-1);
+      /*pn=ncwrs64(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
       p+=pn;
     }
     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
@@ -241,95 +323,21 @@ void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
   }
 }
 
-#if 0
-#include <stdio.h>
-#define NMAX (10)
-#define MMAX (9)
-
-int main(int _argc,char **_argv){
-  int n;
-  for(n=0;n<=NMAX;n++){
-    int m;
-    for(m=0;m<=MMAX;m++){
-      unsigned nc;
-      unsigned i;
-      nc=ncwrs(n,m);
-      for(i=0;i<nc;i++){
-        int x[MMAX];
-        int s[MMAX];
-        int x2[MMAX];
-        int s2[MMAX];
-        int y[NMAX];
-        int j;
-        int k;
-        cwrsi(n,m,i,x,s);
-        printf("%6u of %u:",i,nc);
-        for(k=0;k<m;k++){
-          printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
-        }
-        printf(" ->");
-        if(icwrs(n,m,x,s)!=i){
-          fprintf(stderr,"Combination-index mismatch.\n");
-        }
-        comb2pulse(n,m,y,x,s);
-        for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
-        printf("\n");
-        pulse2comb(n,m,x2,s2,y);
-        for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
-          fprintf(stderr,"Pulse-combination mismatch.\n");
-          break;
-        }
-      }
-      printf("\n");
-    }
-  }
-  return -1;
+void encode_pulses(int *_y, int N, int K, ec_enc *enc)
+{
+   int comb[K];
+   int signs[K];
+   pulse2comb(N, K, comb, signs, _y);
+   celt_uint64_t bound, id;
+   id = icwrs64(N, K, comb, signs, &bound);
+   ec_enc_uint64(enc,id,bound);
 }
-#endif
 
-#if 0
-#include <stdio.h>
-#define NMAX (32)
-#define MMAX (16)
-
-int main(int _argc,char **_argv){
-  int n;
-  for(n=0;n<=NMAX;n+=3){
-    int m;
-    for(m=0;m<=MMAX;m++){
-      celt_uint64_t nc;
-      celt_uint64_t i;
-      nc=ncwrs64(n,m);
-      printf("%d/%d: %llu",n,m, nc);
-      for(i=0;i<nc;i+=100000){
-        int x[MMAX];
-        int s[MMAX];
-        int x2[MMAX];
-        int s2[MMAX];
-        int y[NMAX];
-        int j;
-        int k;
-        cwrsi64(n,m,i,x,s);
-        /*printf("%llu of %llu:",i,nc);
-        for(k=0;k<m;k++){
-          printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
-        }
-        printf(" ->");*/
-        if(icwrs64(n,m,x,s)!=i){
-          fprintf(stderr,"Combination-index mismatch.\n");
-        }
-        comb2pulse(n,m,y,x,s);
-        /*for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
-        printf("\n");*/
-        pulse2comb(n,m,x2,s2,y);
-        for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
-          fprintf(stderr,"Pulse-combination mismatch.\n");
-          break;
-        }
-      }
-      printf("\n");
-    }
-  }
-  return 0;
+void decode_pulses(int *_y, int N, int K, ec_dec *dec)
+{
+   int comb[K];
+   int signs[K];   
+   cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
+   comb2pulse(N, K, _y, comb, signs);
 }
-#endif
+