Compute the real maximum required bits for a split.
[opus.git] / libcelt / cwrs.c
index b779578..19681fd 100644 (file)
@@ -298,25 +298,59 @@ static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
   RESTORE_STACK;
 }
 
-int get_required_bits(int N, int K, int frac)
+int get_required_bits32(int N, int K, int frac)
 {
-   int nbits = 0;
-   if (K==0)
-   {
-      nbits = 0;
-   } else if(fits_in32(N,K))
-   {
-      VARDECL(celt_uint32_t,u);
+   int nbits;
+   VARDECL(celt_uint32_t,u);
+   SAVE_STACK;
+   ALLOC(u,K+2,celt_uint32_t);
+   nbits = log2_frac(ncwrs_u32(N,K,u), frac);
+   RESTORE_STACK;
+   return nbits;
+}
+
+void get_required_bits(celt_int16_t *bits,int N, int MAXK, int frac)
+{
+   int k;
+   /*We special case k==0 below, since fits_in32 could reject it for large N.*/
+   celt_assert(MAXK>0);
+   if(fits_in32(N,MAXK-1)){
+      bits[0]=0;
+      /*This could be sped up one heck of a lot if we didn't recompute u in
+         ncwrs_u32 every time.*/
+      for(k=1;k<MAXK;k++)bits[k]=get_required_bits32(N,k,frac);
+   }
+   else{
+      VARDECL(celt_int16_t,n1bits);
+      VARDECL(celt_int16_t,_n2bits);
+      celt_int16_t *n2bits;
       SAVE_STACK;
-      ALLOC(u,K+2,celt_uint32_t);
-      nbits = log2_frac(ncwrs_u32(N,K,u), frac);
-      RESTORE_STACK;
-   } else {
-      nbits = log2_frac(K+1, frac);
-      nbits += get_required_bits(N/2+1, (K+1)/2, frac);
-      nbits += get_required_bits(N/2+1, K/2, frac);
+      ALLOC(n1bits,MAXK,celt_int16_t);
+      ALLOC(_n2bits,MAXK,celt_int16_t);
+      get_required_bits(n1bits,(N+1)/2,MAXK,frac);
+      if(N&1){
+        n2bits=_n2bits;
+        get_required_bits(n2bits,N/2,MAXK,frac);
+      }else{
+        n2bits=n1bits;
+      }
+      bits[0]=0;
+      for(k=1;k<MAXK;k++){
+         if(fits_in32(N,k))bits[k]=get_required_bits32(N,k,frac);
+         else{
+            int worst_bits;
+            int i;
+            worst_bits=0;
+            for(i=0;i<=k;i++){
+               int split_bits;
+               split_bits=n1bits[i]+n2bits[k-i];
+               if(split_bits>worst_bits)worst_bits=split_bits;
+            }
+            bits[k]=log2_frac(k+1,frac)+worst_bits;
+         }
+      }
+   RESTORE_STACK;
    }
-   return nbits;
 }