Fixed rsqrt testcase for float
[opus.git] / libcelt / cwrs.c
index cc743ad..60880c6 100644 (file)
@@ -1,4 +1,5 @@
-/* (C) 2007 Timothy B. Terriberry */
+/* (C) 2007-2008 Timothy B. Terriberry
+   (C) 2008 Jean-Marc Valin */
 /*
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
+
+/* Functions for encoding and decoding pulse vectors.
+   These are based on the function
+     U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
+     U(n,1) = U(1,m) = 2,
+    which counts the number of ways of placing m pulses in n dimensions, where
+     at least one pulse lies in dimension 0.
+   For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
+*/
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
 #include <stdlib.h>
+#include <string.h>
 #include "cwrs.h"
+#include "mathops.h"
 
-/*Returns the numer of ways of choosing _m elements from a set of size _n with
-   replacement when a sign bit is needed for each unique element.*/
-#if 0
-static celt_uint32_t ncwrs(int _n,int _m){
-  static celt_uint32_t c[32][32];
-  if(_n<0||_m<0)return 0;
-  if(!c[_n][_m]){
-    if(_m<=0)c[_n][_m]=1;
-    else if(_n>0)c[_n][_m]=ncwrs(_n-1,_m)+ncwrs(_n,_m-1)+ncwrs(_n-1,_m-1);
+/*Computes the next row/column of any recurrence that obeys the relation
+   u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
+  _ui0 is the base case for the new row/column.*/
+static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
+  celt_uint32_t ui1;
+  int           j;
+  for(j=1;j<_len;j++){
+    ui1=_ui[j]+_ui[j-1]+_ui0;
+    _ui[j-1]=_ui0;
+    _ui0=ui1;
   }
-  return c[_n][_m];
+  _ui[j-1]=_ui0;
 }
-#else
-celt_uint32_t ncwrs(int _n,int _m){
-  celt_uint32_t ret;
-  celt_uint32_t f;
-  celt_uint32_t d;
-  int      i;
-  if(_n<0||_m<0)return 0;
-  if(_m==0)return 1;
-  if(_n==0)return 0;
-  ret=0;
-  f=_n;
-  d=1;
-  for(i=1;i<=_m;i++){
-    ret+=f*d<<i;
-    f=(f*(_n-i))/(i+1);
-    d=(d*(_m-i))/i;
+
+static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
+  celt_uint64_t ui1;
+  int           j;
+  for(j=1;j<_len;j++){
+    ui1=_ui[j]+_ui[j-1]+_ui0;
+    _ui[j-1]=_ui0;
+    _ui0=ui1;
   }
-  return ret;
+  _ui[j-1]=_ui0;
 }
-#endif
 
-#if 0
-celt_uint64_t ncwrs64(int _n,int _m){
-  static celt_uint64_t c[100][100];
-  if(_n<0||_m<0)return 0;
-  if(!c[_n][_m]){
-    if(_m<=0)c[_n][_m]=1;
-    else if(_n>0)c[_n][_m]=ncwrs64(_n-1,_m)+ncwrs64(_n,_m-1)+ncwrs64(_n-1,_m-1);
+/*Computes the previous row/column of any recurrence that obeys the relation
+   u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
+  _ui0 is the base case for the new row/column.*/
+static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
+  celt_uint32_t ui1;
+  int           j;
+  for(j=1;j<_n;j++){
+    ui1=_ui[j]-_ui[j-1]-_ui0;
+    _ui[j-1]=_ui0;
+    _ui0=ui1;
+  }
+  _ui[j-1]=_ui0;
+}
+
+static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
+  celt_uint64_t ui1;
+  int           j;
+  for(j=1;j<_n;j++){
+    ui1=_ui[j]-_ui[j-1]-_ui0;
+    _ui[j-1]=_ui0;
+    _ui0=ui1;
+  }
+  _ui[j-1]=_ui0;
 }
-  return c[_n][_m];
+
+/*Returns the number of ways of choosing _m elements from a set of size _n with
+   replacement when a sign bit is needed for each unique element.
+  On input, _u should be initialized to column (_m-1) of U(n,m).
+  On exit, _u will be initialized to column _m of U(n,m).*/
+celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
+  celt_uint32_t ret;
+  celt_uint32_t ui0;
+  celt_uint32_t ui1;
+  int           j;
+  ret=ui0=2;
+  for(j=1;j<_n;j++){
+    ui1=_ui[j]+_ui[j-1]+ui0;
+    _ui[j-1]=ui0;
+    ui0=ui1;
+    ret+=ui0;
+  }
+  _ui[j-1]=ui0;
+  return ret;
 }
-#else
-celt_uint64_t ncwrs64(int _n,int _m){
+
+celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
   celt_uint64_t ret;
-  celt_uint64_t f;
-  celt_uint64_t d;
-  int           i;
-  if(_n<0||_m<0)return 0;
-  if(_m==0)return 1;
-  if(_n==0)return 0;
-  ret=0;
-  f=_n;
-  d=1;
-  for(i=1;i<=_m;i++){
-    ret+=f*d<<i;
-    f=(f*(_n-i))/(i+1);
-    d=(d*(_m-i))/i;
+  celt_uint64_t ui0;
+  celt_uint64_t ui1;
+  int           j;
+  ret=ui0=2;
+  for(j=1;j<_n;j++){
+    ui1=_ui[j]+_ui[j-1]+ui0;
+    _ui[j-1]=ui0;
+    ui0=ui1;
+    ret+=ui0;
   }
+  _ui[j-1]=ui0;
   return ret;
 }
-#endif
+
+/*Returns the number of ways of choosing _m elements from a set of size _n with
+   replacement when a sign bit is needed for each unique element.
+  On exit, _u will be initialized to column _m of U(n,m).*/
+celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
+  int k;
+  memset(_u,0,_n*sizeof(*_u));
+  if(_m<=0)return 1;
+  if(_n<=0)return 0;
+  for(k=1;k<_m;k++)unext32(_u,_n,2);
+  return ncwrs_unext32(_n,_u);
+}
+
+celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
+  int k;
+  memset(_u,0,_n*sizeof(*_u));
+  if(_m<=0)return 1;
+  if(_n<=0)return 0;
+  for(k=1;k<_m;k++)unext64(_u,_n,2);
+  return ncwrs_unext64(_n,_u);
+}
 
 /*Returns the _i'th combination of _m elements chosen from a set of size _n
    with associated sign bits.
-  _x:      Returns the combination with elements sorted in ascending order.
-  _s:      Returns the associated sign bits.*/
-void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
+  _x: Returns the combination with elements sorted in ascending order.
+  _s: Returns the associated sign bits.
+  _u: Temporary storage already initialized to column _m of U(n,m).
+      Its contents will be overwritten.*/
+void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
   int j;
   int k;
   for(k=j=0;k<_m;k++){
-    celt_uint32_t pn;
     celt_uint32_t p;
     celt_uint32_t t;
-    p=ncwrs(_n-j,_m-k-1);
-    pn=ncwrs(_n-j-1,_m-k-1);
-    p+=pn;
+    p=_u[_n-j-1];
     if(k>0){
       t=p>>1;
       if(t<=_i||_s[k-1])_i+=t;
@@ -115,59 +174,23 @@ void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
     while(p<=_i){
       _i-=p;
       j++;
-      p=pn;
-      pn=ncwrs(_n-j-1,_m-k-1);
-      p+=pn;
+      p=_u[_n-j-1];
     }
     t=p>>1;
     _s[k]=_i>=t;
     _x[k]=j;
     if(_s[k])_i-=t;
+    uprev32(_u,_n-j,2);
   }
 }
 
-/*Returns the index of the given combination of _m elements chosen from a set
-   of size _n with associated sign bits.
-  _x:      The combination with elements sorted in ascending order.
-  _s:      The associated sign bits.*/
-celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s){
-  celt_uint32_t i;
-  int      j;
-  int      k;
-  i=0;
-  for(k=j=0;k<_m;k++){
-    celt_uint32_t pn;
-    celt_uint32_t p;
-    p=ncwrs(_n-j,_m-k-1);
-    pn=ncwrs(_n-j-1,_m-k-1);
-    p+=pn;
-    if(k>0)p>>=1;
-    while(j<_x[k]){
-      i+=p;
-      j++;
-      p=pn;
-      pn=ncwrs(_n-j-1,_m-k-1);
-      p+=pn;
-    }
-    if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
-  }
-  return i;
-}
-
-/*Returns the _i'th combination of _m elements chosen from a set of size _n
-   with associated sign bits.
-  _x:      Returns the combination with elements sorted in ascending order.
-  _s:      Returns the associated sign bits.*/
-void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
+void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){
   int j;
   int k;
   for(k=j=0;k<_m;k++){
-    celt_uint64_t pn;
     celt_uint64_t p;
     celt_uint64_t t;
-    p=ncwrs64(_n-j,_m-k-1);
-    pn=ncwrs64(_n-j-1,_m-k-1);
-    p+=pn;
+    p=_u[_n-j-1];
     if(k>0){
       t=p>>1;
       if(t<=_i||_s[k-1])_i+=t;
@@ -175,41 +198,60 @@ void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
     while(p<=_i){
       _i-=p;
       j++;
-      p=pn;
-      pn=ncwrs64(_n-j-1,_m-k-1);
-      p+=pn;
+      p=_u[_n-j-1];
     }
     t=p>>1;
     _s[k]=_i>=t;
     _x[k]=j;
     if(_s[k])_i-=t;
+    uprev64(_u,_n-j,2);
   }
 }
 
 /*Returns the index of the given combination of _m elements chosen from a set
    of size _n with associated sign bits.
-  _x:      The combination with elements sorted in ascending order.
-  _s:      The associated sign bits.*/
-celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s){
+  _x: The combination with elements sorted in ascending order.
+  _s: The associated sign bits.
+  _u: Temporary storage already initialized to column _m of U(n,m).
+      Its contents will be overwritten.*/
+celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
+ celt_uint32_t *_u){
+  celt_uint32_t i;
+  int           j;
+  int           k;
+  i=0;
+  for(k=j=0;k<_m;k++){
+    celt_uint32_t p;
+    p=_u[_n-j-1];
+    if(k>0)p>>=1;
+    while(j<_x[k]){
+      i+=p;
+      j++;
+      p=_u[_n-j-1];
+    }
+    if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
+    uprev32(_u,_n-j,2);
+  }
+  return i;
+}
+
+celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
+ celt_uint64_t *_u){
   celt_uint64_t i;
   int           j;
   int           k;
   i=0;
   for(k=j=0;k<_m;k++){
-    celt_uint64_t pn;
     celt_uint64_t p;
-    p=ncwrs64(_n-j,_m-k-1);
-    pn=ncwrs64(_n-j-1,_m-k-1);
-    p+=pn;
+    p=_u[_n-j-1];
     if(k>0)p>>=1;
     while(j<_x[k]){
       i+=p;
       j++;
-      p=pn;
-      pn=ncwrs64(_n-j-1,_m-k-1);
-      p+=pn;
+      p=_u[_n-j-1];
     }
     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
+    uprev64(_u,_n-j,2);
   }
   return i;
 }
@@ -253,3 +295,83 @@ void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
   }
 }
 
+static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
+ ec_enc *_enc){
+  VARDECL(celt_uint32_t,u);
+  celt_uint32_t nc;
+  celt_uint32_t i;
+  SAVE_STACK;
+  ALLOC(u,_n,celt_uint32_t);
+  nc=ncwrs_u32(_n,_m,u);
+  i=icwrs32(_n,_m,_x,_s,u);
+  ec_enc_uint(_enc,i,nc);
+  RESTORE_STACK;
+}
+
+static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s,
+ ec_enc *_enc){
+  VARDECL(celt_uint64_t,u);
+  celt_uint64_t nc;
+  celt_uint64_t i;
+  SAVE_STACK;
+  ALLOC(u,_n,celt_uint64_t);
+  nc=ncwrs_u64(_n,_m,u);
+  i=icwrs64(_n,_m,_x,_s,u);
+  ec_enc_uint64(_enc,i,nc);
+  RESTORE_STACK;
+}
+
+void encode_pulses(int *_y, int N, int K, ec_enc *enc)
+{
+   VARDECL(int, comb);
+   VARDECL(int, signs);
+   SAVE_STACK;
+
+   ALLOC(comb, K, int);
+   ALLOC(signs, K, int);
+
+   pulse2comb(N, K, comb, signs, _y);
+   /* Simple heuristic to figure out whether it fits in 32 bits */
+   if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
+   {
+      encode_comb32(N, K, comb, signs, enc);
+   } else {
+      encode_comb64(N, K, comb, signs, enc);
+   }
+   RESTORE_STACK;
+}
+
+static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
+  VARDECL(celt_uint32_t,u);
+  SAVE_STACK;
+  ALLOC(u,_n,celt_uint32_t);
+  cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
+  RESTORE_STACK;
+}
+
+static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
+  VARDECL(celt_uint64_t,u);
+  SAVE_STACK;
+  ALLOC(u,_n,celt_uint64_t);
+  cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u);
+  RESTORE_STACK;
+}
+
+void decode_pulses(int *_y, int N, int K, ec_dec *dec)
+{
+   VARDECL(int, comb);
+   VARDECL(int, signs);
+   SAVE_STACK;
+
+   ALLOC(comb, K, int);
+   ALLOC(signs, K, int);
+   /* Simple heuristic to figure out whether it fits in 32 bits */
+   if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
+   {
+      decode_comb32(N, K, comb, signs, dec);
+   } else {
+      decode_comb64(N, K, comb, signs, dec);
+   }
+   comb2pulse(N, K, _y, comb, signs);
+   RESTORE_STACK;
+}