Fixed rsqrt testcase for float
[opus.git] / libcelt / cwrs.c
index 5184794..60880c6 100644 (file)
@@ -1,4 +1,4 @@
-/* (C) 2007 Timothy B. Terriberry
+/* (C) 2007-2008 Timothy B. Terriberry
    (C) 2008 Jean-Marc Valin */
 /*
    Redistribution and use in source and binary forms, with or without
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
 
-/* Functions for encoding and decoding pulse vectors. For more details, see:
-   http://people.xiph.org/~tterribe/notes/cwrs.html
+/* Functions for encoding and decoding pulse vectors.
+   These are based on the function
+     U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1),
+     U(n,1) = U(1,m) = 2,
+    which counts the number of ways of placing m pulses in n dimensions, where
+     at least one pulse lies in dimension 0.
+   For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html
 */
 
 #ifdef HAVE_CONFIG_H
 #endif
 
 #include <stdlib.h>
+#include <string.h>
 #include "cwrs.h"
+#include "mathops.h"
 
-/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
-   compute ncwrs() for m+1, for all n. Could also be used when m and n are
-   swapped just by changing nc */
-static void next_ncwrs32(celt_uint32_t *nc, int len, int nc0)
-{
-   int i;
-   celt_uint32_t mem;
-   
-   mem = nc[0];
-   nc[0] = nc0;
-   for (i=1;i<len;i++)
-   {
-      celt_uint32_t tmp = nc[i]+nc[i-1]+mem;
-      mem = nc[i];
-      nc[i] = tmp;
-   }
+/*Computes the next row/column of any recurrence that obeys the relation
+   u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
+  _ui0 is the base case for the new row/column.*/
+static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){
+  celt_uint32_t ui1;
+  int           j;
+  for(j=1;j<_len;j++){
+    ui1=_ui[j]+_ui[j-1]+_ui0;
+    _ui[j-1]=_ui0;
+    _ui0=ui1;
+  }
+  _ui[j-1]=_ui0;
 }
 
-/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
-   compute ncwrs() for m-1, for all n. Could also be used when m and n are
-   swapped just by changing nc */
-static void prev_ncwrs32(celt_uint32_t *nc, int len, int nc0)
-{
-   int i;
-   celt_uint32_t mem;
-   
-   mem = nc[0];
-   nc[0] = nc0;
-   for (i=1;i<len;i++)
-   {
-      celt_uint32_t tmp = nc[i]-nc[i-1]-mem;
-      mem = nc[i];
-      nc[i] = tmp;
-   }
+static inline void unext64(celt_uint64_t *_ui,int _len,celt_uint64_t _ui0){
+  celt_uint64_t ui1;
+  int           j;
+  for(j=1;j<_len;j++){
+    ui1=_ui[j]+_ui[j-1]+_ui0;
+    _ui[j-1]=_ui0;
+    _ui0=ui1;
+  }
+  _ui[j-1]=_ui0;
 }
 
-static void next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
-{
-   int i;
-   celt_uint64_t mem;
-   
-   mem = nc[0];
-   nc[0] = nc0;
-   for (i=1;i<len;i++)
-   {
-      celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
-      mem = nc[i];
-      nc[i] = tmp;
-   }
+/*Computes the previous row/column of any recurrence that obeys the relation
+   u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1].
+  _ui0 is the base case for the new row/column.*/
+static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){
+  celt_uint32_t ui1;
+  int           j;
+  for(j=1;j<_n;j++){
+    ui1=_ui[j]-_ui[j-1]-_ui0;
+    _ui[j-1]=_ui0;
+    _ui0=ui1;
+  }
+  _ui[j-1]=_ui0;
 }
 
-static void prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
-{
-   int i;
-   celt_uint64_t mem;
-   
-   mem = nc[0];
-   nc[0] = nc0;
-   for (i=1;i<len;i++)
-   {
-      celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
-      mem = nc[i];
-      nc[i] = tmp;
-   }
+static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){
+  celt_uint64_t ui1;
+  int           j;
+  for(j=1;j<_n;j++){
+    ui1=_ui[j]-_ui[j-1]-_ui0;
+    _ui[j-1]=_ui0;
+    _ui0=ui1;
+  }
+  _ui[j-1]=_ui0;
 }
 
-/*Returns the numer of ways of choosing _m elements from a set of size _n with
-   replacement when a sign bit is needed for each unique element.*/
-celt_uint32_t ncwrs(int _n,int _m)
-{
-   int i;
-   celt_uint32_t ret;
-   VARDECL(celt_uint32_t, nc);
-   SAVE_STACK;
-   ALLOC(nc,_n+1, celt_uint32_t);
-   for (i=0;i<_n+1;i++)
-      nc[i] = 1;
-   for (i=0;i<_m;i++)
-      next_ncwrs32(nc, _n+1, 0);
-   ret = nc[_n];
-   RESTORE_STACK;
-   return ret;
+/*Returns the number of ways of choosing _m elements from a set of size _n with
+   replacement when a sign bit is needed for each unique element.
+  On input, _u should be initialized to column (_m-1) of U(n,m).
+  On exit, _u will be initialized to column _m of U(n,m).*/
+celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
+  celt_uint32_t ret;
+  celt_uint32_t ui0;
+  celt_uint32_t ui1;
+  int           j;
+  ret=ui0=2;
+  for(j=1;j<_n;j++){
+    ui1=_ui[j]+_ui[j-1]+ui0;
+    _ui[j-1]=ui0;
+    ui0=ui1;
+    ret+=ui0;
+  }
+  _ui[j-1]=ui0;
+  return ret;
 }
 
-/*Returns the numer of ways of choosing _m elements from a set of size _n with
-   replacement when a sign bit is needed for each unique element.*/
-celt_uint64_t ncwrs64(int _n,int _m)
-{
-   int i;
-   celt_uint64_t ret;
-   VARDECL(celt_uint64_t, nc);
-   SAVE_STACK;
-   ALLOC(nc,_n+1, celt_uint64_t);
-   for (i=0;i<_n+1;i++)
-      nc[i] = 1;
-   for (i=0;i<_m;i++)
-      next_ncwrs64(nc, _n+1, 0);
-   ret = nc[_n];
-   RESTORE_STACK;
-   return ret;
+celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){
+  celt_uint64_t ret;
+  celt_uint64_t ui0;
+  celt_uint64_t ui1;
+  int           j;
+  ret=ui0=2;
+  for(j=1;j<_n;j++){
+    ui1=_ui[j]+_ui[j-1]+ui0;
+    _ui[j-1]=ui0;
+    ui0=ui1;
+    ret+=ui0;
+  }
+  _ui[j-1]=ui0;
+  return ret;
 }
 
+/*Returns the number of ways of choosing _m elements from a set of size _n with
+   replacement when a sign bit is needed for each unique element.
+  On exit, _u will be initialized to column _m of U(n,m).*/
+celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
+  int k;
+  memset(_u,0,_n*sizeof(*_u));
+  if(_m<=0)return 1;
+  if(_n<=0)return 0;
+  for(k=1;k<_m;k++)unext32(_u,_n,2);
+  return ncwrs_unext32(_n,_u);
+}
+
+celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){
+  int k;
+  memset(_u,0,_n*sizeof(*_u));
+  if(_m<=0)return 1;
+  if(_n<=0)return 0;
+  for(k=1;k<_m;k++)unext64(_u,_n,2);
+  return ncwrs_unext64(_n,_u);
+}
 
 /*Returns the _i'th combination of _m elements chosen from a set of size _n
    with associated sign bits.
-  _x:      Returns the combination with elements sorted in ascending order.
-  _s:      Returns the associated sign bits.*/
-void cwrsi(int _n,int _m,celt_uint32_t _i,int * restrict _x,int * restrict _s){
+  _x: Returns the combination with elements sorted in ascending order.
+  _s: Returns the associated sign bits.
+  _u: Temporary storage already initialized to column _m of U(n,m).
+      Its contents will be overwritten.*/
+void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
   int j;
   int k;
-  VARDECL(celt_uint32_t, nc);
-  SAVE_STACK;
-  ALLOC(nc,_n+1, celt_uint32_t);
-  for (j=0;j<_n+1;j++)
-    nc[j] = 1;
-  for (k=0;k<_m-1;k++)
-    next_ncwrs32(nc, _n+1, 0);
   for(k=j=0;k<_m;k++){
-    celt_uint32_t pn, p, t;
-    /*p=ncwrs(_n-j,_m-k-1);
-    pn=ncwrs(_n-j-1,_m-k-1);*/
-    p=nc[_n-j];
-    pn=nc[_n-j-1];
-    p+=pn;
+    celt_uint32_t p;
+    celt_uint32_t t;
+    p=_u[_n-j-1];
     if(k>0){
       t=p>>1;
       if(t<=_i||_s[k-1])_i+=t;
@@ -171,89 +174,23 @@ void cwrsi(int _n,int _m,celt_uint32_t _i,int * restrict _x,int * restrict _s){
     while(p<=_i){
       _i-=p;
       j++;
-      p=pn;
-      /*pn=ncwrs(_n-j-1,_m-k-1);*/
-      pn=nc[_n-j-1];
-      p+=pn;
+      p=_u[_n-j-1];
     }
     t=p>>1;
     _s[k]=_i>=t;
     _x[k]=j;
     if(_s[k])_i-=t;
-    if (k<_m-2)
-      prev_ncwrs32(nc, _n+1, 0);
-    else
-      prev_ncwrs32(nc, _n+1, 1);
-  }
-  RESTORE_STACK;
-}
-
-/*Returns the index of the given combination of _m elements chosen from a set
-   of size _n with associated sign bits.
-  _x:      The combination with elements sorted in ascending order.
-  _s:      The associated sign bits.*/
-celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){
-  celt_uint32_t i;
-  int      j;
-  int      k;
-  VARDECL(celt_uint32_t, nc);
-  SAVE_STACK;
-  ALLOC(nc,_n+1, celt_uint32_t);
-  for (j=0;j<_n+1;j++)
-    nc[j] = 1;
-  for (k=0;k<_m;k++)
-    next_ncwrs32(nc, _n+1, 0);
-  if (bound)
-    *bound = nc[_n];
-  i=0;
-  for(k=j=0;k<_m;k++){
-    celt_uint32_t pn;
-    celt_uint32_t p;
-    if (k<_m-1)
-      prev_ncwrs32(nc, _n+1, 0);
-    else
-      prev_ncwrs32(nc, _n+1, 1);
-    /*p=ncwrs(_n-j,_m-k-1);
-    pn=ncwrs(_n-j-1,_m-k-1);*/
-    p=nc[_n-j];
-    pn=nc[_n-j-1];
-    p+=pn;
-    if(k>0)p>>=1;
-    while(j<_x[k]){
-      i+=p;
-      j++;
-      p=pn;
-      /*pn=ncwrs(_n-j-1,_m-k-1);*/
-      pn=nc[_n-j-1];
-      p+=pn;
-    }
-    if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
+    uprev32(_u,_n-j,2);
   }
-  RESTORE_STACK;
-  return i;
 }
 
-/*Returns the _i'th combination of _m elements chosen from a set of size _n
-   with associated sign bits.
-  _x:      Returns the combination with elements sorted in ascending order.
-  _s:      Returns the associated sign bits.*/
-void cwrsi64(int _n,int _m,celt_uint64_t _i,int * restrict _x,int * restrict _s){
+void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,celt_uint64_t *_u){
   int j;
   int k;
-  VARDECL(celt_uint64_t, nc);
-  SAVE_STACK;
-  ALLOC(nc,_n+1, celt_uint64_t);
-  for (j=0;j<_n+1;j++)
-    nc[j] = 1;
-  for (k=0;k<_m-1;k++)
-    next_ncwrs64(nc, _n+1, 0);
   for(k=j=0;k<_m;k++){
-    celt_uint64_t pn, p, t;
-    /*p=ncwrs64(_n-j,_m-k-1);
-    pn=ncwrs64(_n-j-1,_m-k-1);*/
-    p=nc[_n-j];
-    pn=nc[_n-j-1];
-    p+=pn;
+    celt_uint64_t p;
+    celt_uint64_t t;
+    p=_u[_n-j-1];
     if(k>0){
       t=p>>1;
       if(t<=_i||_s[k-1])_i+=t;
@@ -261,65 +198,61 @@ void cwrsi64(int _n,int _m,celt_uint64_t _i,int * restrict _x,int * restrict _s)
     while(p<=_i){
       _i-=p;
       j++;
-      p=pn;
-      /*pn=ncwrs64(_n-j-1,_m-k-1);*/
-      pn=nc[_n-j-1];
-      p+=pn;
+      p=_u[_n-j-1];
     }
     t=p>>1;
     _s[k]=_i>=t;
     _x[k]=j;
     if(_s[k])_i-=t;
-    if (k<_m-2)
-      prev_ncwrs64(nc, _n+1, 0);
-    else
-      prev_ncwrs64(nc, _n+1, 1);
+    uprev64(_u,_n-j,2);
   }
-  RESTORE_STACK;
 }
 
 /*Returns the index of the given combination of _m elements chosen from a set
    of size _n with associated sign bits.
-  _x:      The combination with elements sorted in ascending order.
-  _s:      The associated sign bits.*/
-celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
+  _x: The combination with elements sorted in ascending order.
+  _s: The associated sign bits.
+  _u: Temporary storage already initialized to column _m of U(n,m).
+      Its contents will be overwritten.*/
+celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
+ celt_uint32_t *_u){
+  celt_uint32_t i;
+  int           j;
+  int           k;
+  i=0;
+  for(k=j=0;k<_m;k++){
+    celt_uint32_t p;
+    p=_u[_n-j-1];
+    if(k>0)p>>=1;
+    while(j<_x[k]){
+      i+=p;
+      j++;
+      p=_u[_n-j-1];
+    }
+    if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
+    uprev32(_u,_n-j,2);
+  }
+  return i;
+}
+
+celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
+ celt_uint64_t *_u){
   celt_uint64_t i;
   int           j;
   int           k;
-  VARDECL(celt_uint64_t, nc);
-  SAVE_STACK;
-  ALLOC(nc,_n+1, celt_uint64_t);
-  for (j=0;j<_n+1;j++)
-    nc[j] = 1;
-  for (k=0;k<_m;k++)
-    next_ncwrs64(nc, _n+1, 0);
-  if (bound)
-     *bound = nc[_n];
   i=0;
   for(k=j=0;k<_m;k++){
-    celt_uint64_t pn;
     celt_uint64_t p;
-    if (k<_m-1)
-      prev_ncwrs64(nc, _n+1, 0);
-    else
-      prev_ncwrs64(nc, _n+1, 1);
-    /*p=ncwrs64(_n-j,_m-k-1);
-    pn=ncwrs64(_n-j-1,_m-k-1);*/
-    p=nc[_n-j];
-    pn=nc[_n-j-1];
-    p+=pn;
+    p=_u[_n-j-1];
     if(k>0)p>>=1;
     while(j<_x[k]){
       i+=p;
       j++;
-      p=pn;
-      /*pn=ncwrs64(_n-j-1,_m-k-1);*/
-      pn=nc[_n-j-1];
-      p+=pn;
+      p=_u[_n-j-1];
     }
     if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
+    uprev64(_u,_n-j,2);
   }
-  RESTORE_STACK;
   return i;
 }
 
@@ -362,47 +295,83 @@ void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
   }
 }
 
+static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
+ ec_enc *_enc){
+  VARDECL(celt_uint32_t,u);
+  celt_uint32_t nc;
+  celt_uint32_t i;
+  SAVE_STACK;
+  ALLOC(u,_n,celt_uint32_t);
+  nc=ncwrs_u32(_n,_m,u);
+  i=icwrs32(_n,_m,_x,_s,u);
+  ec_enc_uint(_enc,i,nc);
+  RESTORE_STACK;
+}
+
+static inline void encode_comb64(int _n,int _m,const int *_x,const int *_s,
+ ec_enc *_enc){
+  VARDECL(celt_uint64_t,u);
+  celt_uint64_t nc;
+  celt_uint64_t i;
+  SAVE_STACK;
+  ALLOC(u,_n,celt_uint64_t);
+  nc=ncwrs_u64(_n,_m,u);
+  i=icwrs64(_n,_m,_x,_s,u);
+  ec_enc_uint64(_enc,i,nc);
+  RESTORE_STACK;
+}
+
 void encode_pulses(int *_y, int N, int K, ec_enc *enc)
 {
    VARDECL(int, comb);
    VARDECL(int, signs);
    SAVE_STACK;
-   
+
    ALLOC(comb, K, int);
    ALLOC(signs, K, int);
-   
+
    pulse2comb(N, K, comb, signs, _y);
    /* Simple heuristic to figure out whether it fits in 32 bits */
-   if((N+4)*(K+4)<250 || EC_ILOG(N)*K<31)
+   if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
    {
-      celt_uint32_t bound, id;
-      id = icwrs(N, K, comb, signs, &bound);
-      ec_enc_uint(enc,id,bound);
+      encode_comb32(N, K, comb, signs, enc);
    } else {
-      celt_uint64_t bound, id;
-      id = icwrs64(N, K, comb, signs, &bound);
-      ec_enc_uint64(enc,id,bound);
+      encode_comb64(N, K, comb, signs, enc);
    }
    RESTORE_STACK;
 }
 
+static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
+  VARDECL(celt_uint32_t,u);
+  SAVE_STACK;
+  ALLOC(u,_n,celt_uint32_t);
+  cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
+  RESTORE_STACK;
+}
+
+static inline void decode_comb64(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
+  VARDECL(celt_uint64_t,u);
+  SAVE_STACK;
+  ALLOC(u,_n,celt_uint64_t);
+  cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_x,_s,u);
+  RESTORE_STACK;
+}
+
 void decode_pulses(int *_y, int N, int K, ec_dec *dec)
 {
    VARDECL(int, comb);
    VARDECL(int, signs);
    SAVE_STACK;
-   
+
    ALLOC(comb, K, int);
    ALLOC(signs, K, int);
    /* Simple heuristic to figure out whether it fits in 32 bits */
-   if((N+4)*(K+4)<250 || EC_ILOG(N)*K<31)
+   if((N+4)*(K+4)<250 || (celt_ilog2(N)+1)*K<31)
    {
-      cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs);
-      comb2pulse(N, K, _y, comb, signs);
+      decode_comb32(N, K, comb, signs, dec);
    } else {
-      cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
-      comb2pulse(N, K, _y, comb, signs);
+      decode_comb64(N, K, comb, signs, dec);
    }
+   comb2pulse(N, K, _y, comb, signs);
    RESTORE_STACK;
 }
-