fixed-point: unquant_energy_mono() has received the fixed-point code from
[opus.git] / libcelt / cwrs.c
index 19e9731..5d72707 100644 (file)
@@ -1,21 +1,21 @@
-/* (C) 2007 Timothy Terriberry
-*/
+/* (C) 2007 Timothy B. Terriberry
+   (C) 2008 Jean-Marc Valin */
 /*
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
    are met:
-   
+
    - Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.
-   
+
    - Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.
-   
+
    - Neither the name of the Xiph.org Foundation nor the names of its
    contributors may be used to endorse or promote products derived from
    this software without specific prior written permission.
-   
+
    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
 
+/* Functions for encoding and decoding pulse vectors. For more details, see:
+   http://people.xiph.org/~tterribe/notes/cwrs.html
+*/
 
-/*#include <stdio.h>*/
-#include <stdlib.h>
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
 
+#include <stdlib.h>
 #include "cwrs.h"
 
-/*Returns the numer of ways of choosing _m elements from a set of size _n with
-   replacement when a sign bit is needed for each unique element.*/
-#if 0
-static unsigned ncwrs(int _n,int _m){
-  static unsigned c[32][32];
-  if(_n<0||_m<0)return 0;
-  if(!c[_n][_m]){
-    if(_m<=0)c[_n][_m]=1;
-    else if(_n>0)c[_n][_m]=ncwrs(_n-1,_m)+ncwrs(_n,_m-1)+ncwrs(_n-1,_m-1);
-  }
-  return c[_n][_m];
+/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
+   compute ncwrs() for m+1, for all n. Could also be used when m and n are
+   swapped just by changing nc */
+static void next_ncwrs32(celt_uint32_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint32_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint32_t tmp = nc[i]+nc[i-1]+mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
 }
 
-#else
+/* Knowing ncwrs() for a fixed number of pulses m and for all vector sizes n,
+   compute ncwrs() for m-1, for all n. Could also be used when m and n are
+   swapped just by changing nc */
+static void prev_ncwrs32(celt_uint32_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint32_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint32_t tmp = nc[i]-nc[i-1]-mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
+}
 
-/*Returns the greatest common divisor of _a and _b.*/
-static unsigned gcd(unsigned _a,unsigned _b){
-  unsigned r;
-  while(_b){
-    r=_a%_b;
-    _a=_b;
-    _b=r;
-  }
-  return _a;
+static void next_ncwrs64(celt_uint64_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint64_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint64_t tmp = nc[i]+nc[i-1]+mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
 }
 
-/*Returns _a*b/_d, under the assumption that the result is an integer, avoiding
-   overflow.
-  It is assumed, but not required, that _b is smaller than _a.*/
-static unsigned umuldiv(unsigned _a,unsigned _b,unsigned _d){
-  unsigned d;
-  d=gcd(_b,_d);
-  return (_a/(_d/d))*(_b/d);
+static void prev_ncwrs64(celt_uint64_t *nc, int len, int nc0)
+{
+   int i;
+   celt_uint64_t mem;
+   
+   mem = nc[0];
+   nc[0] = nc0;
+   for (i=1;i<len;i++)
+   {
+      celt_uint64_t tmp = nc[i]-nc[i-1]-mem;
+      mem = nc[i];
+      nc[i] = tmp;
+   }
 }
 
-unsigned ncwrs(int _n,int _m){
-  unsigned ret;
-  unsigned f;
-  unsigned d;
-  int      i;
-  if(_n<0||_m<0)return 0;
-  if(_m==0)return 1;
-  if(_n==0)return 0;
-  ret=0;
-  f=_n;
-  d=1;
-  for(i=1;i<=_m;i++){
-    ret+=f*d<<i;
-#if 0
-    f=umuldiv(f,_n-i,i+1);
-    d=umuldiv(d,_m-i,i);
-#else
-    f=(f*(_n-i))/(i+1);
-    d=(d*(_m-i))/i;
-#endif
-  }
-  return ret;
+/*Returns the numer of ways of choosing _m elements from a set of size _n with
+   replacement when a sign bit is needed for each unique element.*/
+celt_uint32_t ncwrs(int _n,int _m)
+{
+   int i;
+   VARDECL(celt_uint32_t *nc);
+   ALLOC(nc,_n+1, celt_uint32_t);
+   for (i=0;i<_n+1;i++)
+      nc[i] = 1;
+   for (i=0;i<_m;i++)
+      next_ncwrs32(nc, _n+1, 0);
+   return nc[_n];
+}
+
+/*Returns the numer of ways of choosing _m elements from a set of size _n with
+   replacement when a sign bit is needed for each unique element.*/
+celt_uint64_t ncwrs64(int _n,int _m)
+{
+   int i;
+   VARDECL(celt_uint64_t *nc);
+   ALLOC(nc,_n+1, celt_uint64_t);
+   for (i=0;i<_n+1;i++)
+      nc[i] = 1;
+   for (i=0;i<_m;i++)
+      next_ncwrs64(nc, _n+1, 0);
+   return nc[_n];
 }
-#endif
+
 
 /*Returns the _i'th combination of _m elements chosen from a set of size _n
    with associated sign bits.
   _x:      Returns the combination with elements sorted in ascending order.
   _s:      Returns the associated sign bits.*/
-void cwrsi(int _n,int _m,unsigned _i,int *_x,int *_s){
-  unsigned pn;
-  int      j;
-  int      k;
-  pn=ncwrs(_n-1,_m);
+void cwrsi(int _n,int _m,celt_uint32_t _i,int *_x,int *_s){
+  int j;
+  int k;
+  VARDECL(celt_uint32_t *nc);
+  ALLOC(nc,_n+1, celt_uint32_t);
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m-1;k++)
+    next_ncwrs32(nc, _n+1, 0);
   for(k=j=0;k<_m;k++){
-    unsigned pp;
-    unsigned p;
-    unsigned t;
-    pp=0;
-    p=ncwrs(_n-j,_m-k)-pn;
+    celt_uint32_t pn, p, t;
+    /*p=ncwrs(_n-j,_m-k-1);
+    pn=ncwrs(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
+    p+=pn;
     if(k>0){
       t=p>>1;
       if(t<=_i||_s[k-1])_i+=t;
     }
-    pn=ncwrs(_n-j-1,_m-k-1);
     while(p<=_i){
-      pp=p;
+      _i-=p;
       j++;
-      p+=pn;
-      pn=ncwrs(_n-j-1,_m-k-1);
+      p=pn;
+      /*pn=ncwrs(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
       p+=pn;
     }
-    t=p-pp>>1;
-    _s[k]=_i-pp>=t;
+    t=p>>1;
+    _s[k]=_i>=t;
     _x[k]=j;
-    _i-=pp;
     if(_s[k])_i-=t;
+    if (k<_m-2)
+      prev_ncwrs32(nc, _n+1, 0);
+    else
+      prev_ncwrs32(nc, _n+1, 1);
   }
 }
 
@@ -134,29 +182,127 @@ void cwrsi(int _n,int _m,unsigned _i,int *_x,int *_s){
    of size _n with associated sign bits.
   _x:      The combination with elements sorted in ascending order.
   _s:      The associated sign bits.*/
-unsigned icwrs(int _n,int _m,const int *_x,const int *_s){
-  unsigned pn;
-  unsigned i;
+celt_uint32_t icwrs(int _n,int _m,const int *_x,const int *_s, celt_uint32_t *bound){
+  celt_uint32_t i;
   int      j;
   int      k;
+  VARDECL(celt_uint32_t *nc);
+  ALLOC(nc,_n+1, celt_uint32_t);
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m;k++)
+    next_ncwrs32(nc, _n+1, 0);
+  if (bound)
+    *bound = nc[_n];
   i=0;
-  pn=ncwrs(_n-1,_m);
   for(k=j=0;k<_m;k++){
-    unsigned pp;
-    unsigned p;
-    pp=0;
-    p=ncwrs(_n-j,_m-k)-pn;
+    celt_uint32_t pn;
+    celt_uint32_t p;
+    if (k<_m-1)
+      prev_ncwrs32(nc, _n+1, 0);
+    else
+      prev_ncwrs32(nc, _n+1, 1);
+    /*p=ncwrs(_n-j,_m-k-1);
+    pn=ncwrs(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
+    p+=pn;
     if(k>0)p>>=1;
-    pn=ncwrs(_n-j-1,_m-k-1);
     while(j<_x[k]){
-      pp=p;
+      i+=p;
+      j++;
+      p=pn;
+      /*pn=ncwrs(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
+      p+=pn;
+    }
+    if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
+  }
+  return i;
+}
+
+/*Returns the _i'th combination of _m elements chosen from a set of size _n
+   with associated sign bits.
+  _x:      Returns the combination with elements sorted in ascending order.
+  _s:      Returns the associated sign bits.*/
+void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s){
+  int j;
+  int k;
+  VARDECL(celt_uint64_t *nc);
+  ALLOC(nc,_n+1, celt_uint64_t);
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m-1;k++)
+    next_ncwrs64(nc, _n+1, 0);
+  for(k=j=0;k<_m;k++){
+    celt_uint64_t pn, p, t;
+    /*p=ncwrs64(_n-j,_m-k-1);
+    pn=ncwrs64(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
+    p+=pn;
+    if(k>0){
+      t=p>>1;
+      if(t<=_i||_s[k-1])_i+=t;
+    }
+    while(p<=_i){
+      _i-=p;
       j++;
+      p=pn;
+      /*pn=ncwrs64(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
       p+=pn;
-      pn=ncwrs(_n-j-1,_m-k-1);
+    }
+    t=p>>1;
+    _s[k]=_i>=t;
+    _x[k]=j;
+    if(_s[k])_i-=t;
+    if (k<_m-2)
+      prev_ncwrs64(nc, _n+1, 0);
+    else
+      prev_ncwrs64(nc, _n+1, 1);
+  }
+}
+
+/*Returns the index of the given combination of _m elements chosen from a set
+   of size _n with associated sign bits.
+  _x:      The combination with elements sorted in ascending order.
+  _s:      The associated sign bits.*/
+celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s, celt_uint64_t *bound){
+  celt_uint64_t i;
+  int           j;
+  int           k;
+  VARDECL(celt_uint64_t *nc);
+  ALLOC(nc,_n+1, celt_uint64_t);
+  for (j=0;j<_n+1;j++)
+    nc[j] = 1;
+  for (k=0;k<_m;k++)
+    next_ncwrs64(nc, _n+1, 0);
+  if (bound)
+     *bound = nc[_n];
+  i=0;
+  for(k=j=0;k<_m;k++){
+    celt_uint64_t pn;
+    celt_uint64_t p;
+    if (k<_m-1)
+      prev_ncwrs64(nc, _n+1, 0);
+    else
+      prev_ncwrs64(nc, _n+1, 1);
+    /*p=ncwrs64(_n-j,_m-k-1);
+    pn=ncwrs64(_n-j-1,_m-k-1);*/
+    p=nc[_n-j];
+    pn=nc[_n-j-1];
+    p+=pn;
+    if(k>0)p>>=1;
+    while(j<_x[k]){
+      i+=p;
+      j++;
+      p=pn;
+      /*pn=ncwrs64(_n-j-1,_m-k-1);*/
+      pn=nc[_n-j-1];
       p+=pn;
     }
-    i+=pp;
-    if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p-pp>>1;
+    if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
   }
   return i;
 }
@@ -200,47 +346,42 @@ void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
   }
 }
 
-/*
-#define NMAX (10)
-#define MMAX (9)
+void encode_pulses(int *_y, int N, int K, ec_enc *enc)
+{
+   VARDECL(int *comb);
+   VARDECL(int *signs);
+   
+   ALLOC(comb, K, int);
+   ALLOC(signs, K, int);
+   
+   pulse2comb(N, K, comb, signs, _y);
+   /* Go with 32-bit path if we're sure we can */
+   if (N<=13 && K<=13)
+   {
+      celt_uint32_t bound, id;
+      id = icwrs(N, K, comb, signs, &bound);
+      ec_enc_uint(enc,id,bound);
+   } else {
+      celt_uint64_t bound, id;
+      id = icwrs64(N, K, comb, signs, &bound);
+      ec_enc_uint64(enc,id,bound);
+   }
+}
 
-int main(int _argc,char **_argv){
-  int n;
-  for(n=0;n<=NMAX;n++){
-    int m;
-    for(m=0;m<=MMAX;m++){
-      unsigned nc;
-      unsigned i;
-      nc=ncwrs(n,m);
-      for(i=0;i<nc;i++){
-        int x[MMAX];
-        int s[MMAX];
-        int x2[MMAX];
-        int s2[MMAX];
-        int y[NMAX];
-        int j;
-        int k;
-        cwrsi(n,m,i,x,s);
-        printf("%6u of %u:",i,nc);
-        for(k=0;k<m;k++){
-          printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
-        }
-        printf(" ->");
-        if(icwrs(n,m,x,s)!=i){
-          fprintf(stderr,"Combination-index mismatch.\n");
-        }
-        comb2pulse(n,m,y,x,s);
-        for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
-        printf("\n");
-        pulse2comb(n,m,x2,s2,y);
-        for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
-          fprintf(stderr,"Pulse-combination mismatch.\n");
-          break;
-        }
-      }
-      printf("\n");
-    }
-  }
-  return 0;
+void decode_pulses(int *_y, int N, int K, ec_dec *dec)
+{
+   VARDECL(int *comb);
+   VARDECL(int *signs);
+   
+   ALLOC(comb, K, int);
+   ALLOC(signs, K, int);
+   if (N<=13 && K<=13)
+   {
+      cwrsi(N, K, ec_dec_uint(dec, ncwrs(N, K)), comb, signs);
+      comb2pulse(N, K, _y, comb, signs);
+   } else {
+      cwrsi64(N, K, ec_dec_uint64(dec, ncwrs64(N, K)), comb, signs);
+      comb2pulse(N, K, _y, comb, signs);
+   }
 }
-*/
+