Implemented a cleaner way to detect whether CWRS codebooks fit in 32 or 64 bits
[opus.git] / libcelt / rate.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include "modes.h"
38 #include "cwrs.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 #include "entcode.h"
43 #include "rate.h"
44
45 #define BITRES 4
46 #define BITROUND 8
47 #define BITOVERFLOW 30000
48
49 #ifndef STATIC_MODES
50 #if 0
51 static int log2_frac(ec_uint32 val, int frac)
52 {
53    int i;
54    /* EC_ILOG() actually returns log2()+1, go figure */
55    int L = EC_ILOG(val)-1;
56    /*printf ("in: %d %d ", val, L);*/
57    if (L>14)
58       val >>= L-14;
59    else if (L<14)
60       val <<= 14-L;
61    L <<= frac;
62    /*printf ("%d\n", val);*/
63    for (i=0;i<frac;i++)
64    {
65       val = (val*val) >> 15;
66       /*printf ("%d\n", val);*/
67       if (val > 16384)
68          L |= (1<<(frac-i-1));
69       else   
70          val <<= 1;
71    }
72    return L;
73 }
74 #endif
75
76 static int log2_frac64(ec_uint64 val, int frac)
77 {
78    int i;
79    /* EC_ILOG64() actually returns log2()+1, go figure */
80    int L = EC_ILOG64(val)-1;
81    /*printf ("in: %d %d ", val, L);*/
82    if (L>14)
83       val >>= L-14;
84    else if (L<14)
85       val <<= 14-L;
86    L <<= frac;
87    /*printf ("%d\n", val);*/
88    for (i=0;i<frac;i++)
89    {
90       val = (val*val) >> 15;
91       /*printf ("%d\n", val);*/
92       if (val > 16384)
93          L |= (1<<(frac-i-1));
94       else   
95          val <<= 1;
96    }
97    return L;
98 }
99
100 celt_int16_t **compute_alloc_cache(CELTMode *m, int C)
101 {
102    int i, prevN;
103    celt_int16_t **bits;
104    const celt_int16_t *eBands = m->eBands;
105
106    bits = celt_alloc(m->nbEBands*sizeof(celt_int16_t*));
107    
108    prevN = -1;
109    for (i=0;i<m->nbEBands;i++)
110    {
111       int N = C*(eBands[i+1]-eBands[i]);
112       if (N == prevN && eBands[i] < m->pitchEnd)
113       {
114          bits[i] = bits[i-1];
115       } else {
116          int j;
117          VARDECL(celt_uint64_t, u);
118          SAVE_STACK;
119          ALLOC(u, N, celt_uint64_t);
120          /* FIXME: We could save memory here */
121          bits[i] = celt_alloc(MAX_PULSES*sizeof(celt_int16_t));
122          for (j=0;j<MAX_PULSES;j++)
123          {
124             int done = 0;
125             int pulses = j;
126             /* For bands where there's no pitch, id 1 corresponds to intra prediction 
127             with no pulse. id 2 means intra prediction with one pulse, and so on.*/
128             if (eBands[i] >= m->pitchEnd)
129                pulses -= 1;
130             if (pulses < 0)
131                bits[i][j] = 0;
132             else {
133                celt_uint64_t nc;
134                if (!fits_in64(N, pulses))
135                   break;
136                nc=pulses?ncwrs_unext64(N, u):ncwrs_u64(N, 0, u);
137                bits[i][j] = log2_frac64(nc,BITRES);
138                /* Add the intra-frame prediction sign bit */
139                if (eBands[i] >= m->pitchEnd)
140                   bits[i][j] += (1<<BITRES);
141             }
142             if (done)
143                break;
144          }
145          for (;j<MAX_PULSES;j++)
146             bits[i][j] = BITOVERFLOW;
147          prevN = N;
148          RESTORE_STACK;
149       }
150    }
151    return bits;
152 }
153
154 #endif /* !STATIC_MODES */
155
156 static inline int bits2pulses(const CELTMode *m, const celt_int16_t *cache, int bits)
157 {
158    int i;
159    int lo, hi;
160    lo = 0;
161    hi = MAX_PULSES-1;
162    
163    /* Instead of using the "bisection condition" we use a fixed number of 
164       iterations because it should be faster */
165    /*while (hi-lo != 1)*/
166    for (i=0;i<LOG_MAX_PULSES;i++)
167    {
168       int mid = (lo+hi)>>1;
169       /* OPT: Make sure this is implemented with a conditional move */
170       if (cache[mid] >= bits)
171          hi = mid;
172       else
173          lo = mid;
174    }
175    if (bits-cache[lo] <= cache[hi]-bits)
176       return lo;
177    else
178       return hi;
179 }
180
181 static int vec_bits2pulses(const CELTMode *m, const celt_int16_t * const *cache, int *bits, int *pulses, int len)
182 {
183    int i;
184    int sum=0;
185
186    for (i=0;i<len;i++)
187    {
188       pulses[i] = bits2pulses(m, cache[i], bits[i]);
189       sum += cache[i][pulses[i]];
190    }
191    /*printf ("sum = %d\n", sum);*/
192    return sum;
193 }
194
195 static int interp_bits2pulses(const CELTMode *m, const celt_int16_t * const *cache, int *bits1, int *bits2, int total, int *pulses, int len)
196 {
197    int lo, hi, out;
198    int j;
199    VARDECL(int, bits);
200    SAVE_STACK;
201    ALLOC(bits, len, int);
202    lo = 0;
203    hi = 1<<BITRES;
204    while (hi-lo != 1)
205    {
206       int mid = (lo+hi)>>1;
207       for (j=0;j<len;j++)
208          bits[j] = ((1<<BITRES)-mid)*bits1[j] + mid*bits2[j];
209       if (vec_bits2pulses(m, cache, bits, pulses, len) > total<<BITRES)
210          hi = mid;
211       else
212          lo = mid;
213    }
214    /*printf ("interp bisection gave %d\n", lo);*/
215    for (j=0;j<len;j++)
216       bits[j] = ((1<<BITRES)-lo)*bits1[j] + lo*bits2[j];
217    out = vec_bits2pulses(m, cache, bits, pulses, len);
218    /* Do some refinement to use up all bits. In the first pass, we can only add pulses to 
219       bands that are under their allocated budget. In the second pass, anything goes */
220    for (j=0;j<len;j++)
221    {
222       if (cache[j][pulses[j]] < bits[j] && pulses[j]<MAX_PULSES-1)
223       {
224          if (out+cache[j][pulses[j]+1]-cache[j][pulses[j]] <= total<<BITRES)
225          {
226             out = out+cache[j][pulses[j]+1]-cache[j][pulses[j]];
227             pulses[j] += 1;
228          }
229       }
230    }
231    while(1)
232    {
233       int incremented = 0;
234       for (j=0;j<len;j++)
235       {
236          if (pulses[j]<MAX_PULSES-1)
237          {
238             if (out+cache[j][pulses[j]+1]-cache[j][pulses[j]] <= total<<BITRES)
239             {
240                out = out+cache[j][pulses[j]+1]-cache[j][pulses[j]];
241                pulses[j] += 1;
242                incremented = 1;
243             }
244          }
245       }
246       if (!incremented)
247             break;
248    }
249    RESTORE_STACK;
250    return (out+BITROUND) >> BITRES;
251 }
252
253 int compute_allocation(const CELTMode *m, int *offsets, const int *stereo_mode, int total, int *pulses)
254 {
255    int lo, hi, len, ret, i;
256    VARDECL(int, bits1);
257    VARDECL(int, bits2);
258    VARDECL(const celt_int16_t*, cache);
259    SAVE_STACK;
260    
261    len = m->nbEBands;
262    ALLOC(bits1, len, int);
263    ALLOC(bits2, len, int);
264    ALLOC(cache, len, const celt_int16_t*);
265    
266    if (m->nbChannels==2)
267    {
268       for (i=0;i<len;i++)
269       {
270          if (stereo_mode[i]==0)
271             cache[i] = m->bits_stereo[i];
272          else
273             cache[i] = m->bits[i];
274       }
275    } else {
276       for (i=0;i<len;i++)
277          cache[i] = m->bits[i];
278    }
279    
280    lo = 0;
281    hi = m->nbAllocVectors - 1;
282    while (hi-lo != 1)
283    {
284       int j;
285       int mid = (lo+hi) >> 1;
286       for (j=0;j<len;j++)
287       {
288          bits1[j] = (m->allocVectors[mid*len+j] + offsets[j])<<BITRES;
289          if (bits1[j] < 0)
290             bits1[j] = 0;
291          /*printf ("%d ", bits[j]);*/
292       }
293       /*printf ("\n");*/
294       if (vec_bits2pulses(m, cache, bits1, pulses, len) > total<<BITRES)
295          hi = mid;
296       else
297          lo = mid;
298       /*printf ("lo = %d, hi = %d\n", lo, hi);*/
299    }
300    {
301       int j;
302       for (j=0;j<len;j++)
303       {
304          bits1[j] = m->allocVectors[lo*len+j] + offsets[j];
305          bits2[j] = m->allocVectors[hi*len+j] + offsets[j];
306          if (bits1[j] < 0)
307             bits1[j] = 0;
308          if (bits2[j] < 0)
309             bits2[j] = 0;
310       }
311       ret = interp_bits2pulses(m, cache, bits1, bits2, total, pulses, len);
312       RESTORE_STACK;
313       return ret;
314    }
315 }
316