Move skip coding into interp_bits2pulses().
[opus.git] / libcelt / cwrs.c
index 15135af..843dcd5 100644 (file)
@@ -148,33 +148,6 @@ static inline celt_uint32 imusdiv32even(celt_uint32 _a,celt_uint32 _b,
    (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
 }
 
-/*Compute floor(sqrt(_val)) with exact arithmetic.
-  This has been tested on all possible 32-bit inputs.*/
-static unsigned isqrt32(celt_uint32 _val){
-  unsigned b;
-  unsigned g;
-  int      bshift;
-  /*Uses the second method from
-     http://www.azillionmonkeys.com/qed/sqroot.html
-    The main idea is to search for the largest binary digit b such that
-     (g+b)*(g+b) <= _val, and add it to the solution g.*/
-  g=0;
-  bshift=EC_ILOG(_val)-1>>1;
-  b=1U<<bshift;
-  do{
-    celt_uint32 t;
-    t=((celt_uint32)g<<1)+b<<bshift;
-    if(t<=_val){
-      g+=b;
-      _val-=t;
-    }
-    b>>=1;
-    bshift--;
-  }
-  while(bshift>=0);
-  return g;
-}
-
 #endif /* SMALL_FOOTPRINT */
 
 /*Although derived separately, the pulse vector coding scheme is equivalent to
@@ -296,27 +269,6 @@ static unsigned isqrt32(celt_uint32 _val){
     year=1986
   }*/
 
-/*Determines if V(N,K) fits in a 32-bit unsigned integer.
-  N and K are themselves limited to 15 bits.*/
-int fits_in32(int _n, int _k)
-{
-   static const celt_int16 maxN[15] = {
-      32767, 32767, 32767, 1476, 283, 109,  60,  40,
-       29,  24,  20,  18,  16,  14,  13};
-   static const celt_int16 maxK[15] = {
-      32767, 32767, 32767, 32767, 1172, 238,  95,  53,
-       36,  27,  22,  18,  16,  15,  13};
-   if (_n>=14)
-   {
-      if (_k>=14)
-         return 0;
-      else
-         return _n <= maxN[_k];
-   } else {
-      return _k <= maxK[_n];
-   }
-}
-
 #ifndef SMALL_FOOTPRINT
 
 /*Compute U(1,_k).*/
@@ -446,8 +398,6 @@ static celt_uint32 ncwrs_urow(unsigned _n,unsigned _k,celt_uint32 *_u){
   return _u[_k]+_u[_k+1];
 }
 
-#ifndef SMALL_FOOTPRINT
-
 /*Returns the _i'th combination of _k elements (at most 32767) chosen from a
    set of size 1 with associated sign bits.
   _y: Returns the vector of pulses.*/
@@ -457,6 +407,8 @@ static inline void cwrsi1(int _k,celt_uint32 _i,int *_y){
   _y[0]=_k+s^s;
 }
 
+#ifndef SMALL_FOOTPRINT
+
 /*Returns the _i'th combination of _k elements (at most 32767) chosen from a
    set of size 2 with associated sign bits.
   _y: Returns the vector of pulses.*/
@@ -692,123 +644,33 @@ celt_uint32 icwrs(int _n,int _k,celt_uint32 *_nc,const int *_y,
   return i;
 }
 
-
-/*Computes get_required_bits when splitting is required.
-  _left_bits and _right_bits must contain the required bits for the left and
-   right sides of the split, respectively (which themselves may require
-   splitting).*/
-static void get_required_split_bits(celt_int16 *_bits,
- const celt_int16 *_left_bits,const celt_int16 *_right_bits,
- int _n,int _maxk,int _frac){
-  int k;
-  for(k=_maxk;k-->0;){
-    /*If we've reached a k where everything fits in 32 bits, evaluate the
-       remaining required bits directly.*/
-    if(fits_in32(_n,k)){
-      get_required_bits(_bits,_n,k+1,_frac);
-      break;
-    }
-    else{
-      int worst_bits;
-      int i;
-      /*Due to potentially recursive splitting, it's difficult to derive an
-         analytic expression for the location of the worst-case split index.
-        We simply check them all.*/
-      worst_bits=0;
-      for(i=0;i<=k;i++){
-        int split_bits;
-        split_bits=_left_bits[i]+_right_bits[k-i];
-        if(split_bits>worst_bits)worst_bits=split_bits;
-      }
-      _bits[k]=log2_frac(k+1,_frac)+worst_bits;
-    }
-  }
-}
-
-/*Computes get_required_bits for a pair of N values.
-  _n1 and _n2 must either be equal or two consecutive integers.
-  Returns the buffer used to store the required bits for _n2, which is either
-   _bits1 if _n1==_n2 or _bits2 if _n1+1==_n2.*/
-static celt_int16 *get_required_bits_pair(celt_int16 *_bits1,
- celt_int16 *_bits2,celt_int16 *_tmp,int _n1,int _n2,int _maxk,int _frac){
-  celt_int16 *tmp2;
-  /*If we only need a single set of required bits...*/
-  if(_n1==_n2){
-    /*Stop recursing if everything fits.*/
-    if(fits_in32(_n1,_maxk-1))get_required_bits(_bits1,_n1,_maxk,_frac);
-    else{
-      _tmp=get_required_bits_pair(_bits2,_tmp,_bits1,
-       _n1>>1,_n1+1>>1,_maxk,_frac);
-      get_required_split_bits(_bits1,_bits2,_tmp,_n1,_maxk,_frac);
-    }
-    return _bits1;
-  }
-  /*Otherwise we need two distinct sets...*/
-  celt_assert(_n1+1==_n2);
-  /*Stop recursing if everything fits.*/
-  if(fits_in32(_n2,_maxk-1)){
-    get_required_bits(_bits1,_n1,_maxk,_frac);
-    get_required_bits(_bits2,_n2,_maxk,_frac);
-  }
-  /*Otherwise choose an evaluation order that doesn't require extra buffers.*/
-  else if(_n1&1){
-    /*This special case isn't really needed, but can save some work.*/
-    if(fits_in32(_n1,_maxk-1)){
-      tmp2=get_required_bits_pair(_tmp,_bits1,_bits2,
-       _n2>>1,_n2>>1,_maxk,_frac);
-      get_required_split_bits(_bits2,_tmp,tmp2,_n2,_maxk,_frac);
-      get_required_bits(_bits1,_n1,_maxk,_frac);
-    }
-    else{
-      _tmp=get_required_bits_pair(_bits2,_tmp,_bits1,
-       _n1>>1,_n1+1>>1,_maxk,_frac);
-      get_required_split_bits(_bits1,_bits2,_tmp,_n1,_maxk,_frac);
-      get_required_split_bits(_bits2,_tmp,_tmp,_n2,_maxk,_frac);
-    }
-  }
-  else{
-    /*There's no need to special case _n1 fitting by itself, since _n2 requires
-       us to recurse for both values anyway.*/
-    tmp2=get_required_bits_pair(_tmp,_bits1,_bits2,
-     _n2>>1,_n2+1>>1,_maxk,_frac);
-    get_required_split_bits(_bits2,_tmp,tmp2,_n2,_maxk,_frac);
-    get_required_split_bits(_bits1,_tmp,_tmp,_n1,_maxk,_frac);
-  }
-  return _bits2;
-}
-
+#ifndef STATIC_MODES
 void get_required_bits(celt_int16 *_bits,int _n,int _maxk,int _frac){
   int k;
   /*_maxk==0 => there's nothing to do.*/
   celt_assert(_maxk>0);
-  if(fits_in32(_n,_maxk-1)){
-    _bits[0]=0;
-    if(_maxk>1){
-      VARDECL(celt_uint32,u);
-      SAVE_STACK;
-      ALLOC(u,_maxk+1U,celt_uint32);
-      ncwrs_urow(_n,_maxk-1,u);
-      for(k=1;k<_maxk;k++)_bits[k]=log2_frac(u[k]+u[k+1],_frac);
-      RESTORE_STACK;
-    }
+  _bits[0]=0;
+  if (_n==1)
+  {
+    for (k=1;k<=_maxk;k++)
+      _bits[k] = 1<<_frac;
   }
-  else{
-    VARDECL(celt_int16,n1bits);
-    VARDECL(celt_int16,n2bits_buf);
-    celt_int16 *n2bits;
+  else {
+    VARDECL(celt_uint32,u);
     SAVE_STACK;
-    ALLOC(n1bits,_maxk,celt_int16);
-    ALLOC(n2bits_buf,_maxk,celt_int16);
-    n2bits=get_required_bits_pair(n1bits,n2bits_buf,_bits,
-     _n>>1,_n+1>>1,_maxk,_frac);
-    get_required_split_bits(_bits,n1bits,n2bits,_n,_maxk,_frac);
+    ALLOC(u,_maxk+2U,celt_uint32);
+    ncwrs_urow(_n,_maxk,u);
+    for(k=1;k<=_maxk;k++)
+      _bits[k]=log2_frac(u[k]+u[k+1],_frac);
     RESTORE_STACK;
   }
 }
+#endif /* STATIC_MODES */
 
-
-static inline void encode_pulses32(int _n,int _k,const int *_y,ec_enc *_enc){
+void encode_pulses(const int *_y,int _n,int _k,ec_enc *_enc){
   celt_uint32 i;
+  if (_k==0)
+     return;
   switch(_n){
     case 1:{
       i=icwrs1(_y,&_k);
@@ -846,26 +708,14 @@ static inline void encode_pulses32(int _n,int _k,const int *_y,ec_enc *_enc){
   }
 }
 
-void encode_pulses(int *_y, int N, int K, ec_enc *enc)
+void decode_pulses(int *_y,int _n,int _k,ec_dec *_dec)
 {
-   if (K==0) {
-   } else if(fits_in32(N,K))
-   {
-      encode_pulses32(N, K, _y, enc);
-   } else {
-     int i;
-     int count=0;
-     int split;
-     split = (N+1)/2;
-     for (i=0;i<split;i++)
-        count += abs(_y[i]);
-     ec_enc_uint(enc,count,K+1);
-     encode_pulses(_y, split, count, enc);
-     encode_pulses(_y+split, N-split, K-count, enc);
+   if (_k==0) {
+      int i;
+      for (i=0;i<_n;i++)
+         _y[i] = 0;
+      return;
    }
-}
-
-static inline void decode_pulses32(int _n,int _k,int *_y,ec_dec *_dec){
    switch(_n){
     case 1:{
       celt_assert(ncwrs1(_k)==2);
@@ -888,20 +738,3 @@ static inline void decode_pulses32(int _n,int _k,int *_y,ec_dec *_dec){
   }
 }
 
-void decode_pulses(int *_y, int N, int K, ec_dec *dec)
-{
-   if (K==0) {
-      int i;
-      for (i=0;i<N;i++)
-         _y[i] = 0;
-   } else if(fits_in32(N,K))
-   {
-      decode_pulses32(N, K, _y, dec);
-   } else {
-     int split;
-     int count = ec_dec_uint(dec,K+1);
-     split = (N+1)/2;
-     decode_pulses(_y, split, count, dec);
-     decode_pulses(_y+split, N-split, K-count, dec);
-   }
-}