fixed a bunch of bugs in the unified allocation code.
[opus.git] / libcelt / rate.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include "modes.h"
38 #include "cwrs.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 #include "entcode.h"
43 #include "rate.h"
44
45 #define BITRES 4
46 #define BITROUND 8
47 #define BITOVERFLOW 30000
48
49 #ifndef STATIC_MODES
50
51 celt_int16_t **compute_alloc_cache(CELTMode *m, int C)
52 {
53    int i, prevN;
54    celt_int16_t **bits;
55    const celt_int16_t *eBands = m->eBands;
56
57    bits = celt_alloc(m->nbEBands*sizeof(celt_int16_t*));
58    
59    prevN = -1;
60    for (i=0;i<m->nbEBands;i++)
61    {
62       int N = C*(eBands[i+1]-eBands[i]);
63       if (N == prevN && eBands[i] < m->pitchEnd)
64       {
65          bits[i] = bits[i-1];
66       } else {
67          int j;
68          /* FIXME: We could save memory here */
69          bits[i] = celt_alloc(MAX_PULSES*sizeof(celt_int16_t));
70          for (j=0;j<MAX_PULSES;j++)
71          {
72             int pulses = j;
73             /* For bands where there's no pitch, id 1 corresponds to intra prediction 
74             with no pulse. id 2 means intra prediction with one pulse, and so on.*/
75             if (eBands[i] >= m->pitchEnd)
76                pulses -= 1;
77             if (pulses < 0)
78                bits[i][j] = 0;
79             else {
80                bits[i][j] = get_required_bits(N, pulses, BITRES);
81                /* Add the intra-frame prediction sign bit */
82                if (eBands[i] >= m->pitchEnd)
83                   bits[i][j] += (1<<BITRES);
84             }
85          }
86          for (;j<MAX_PULSES;j++)
87             bits[i][j] = BITOVERFLOW;
88          prevN = N;
89       }
90    }
91    return bits;
92 }
93
94 #endif /* !STATIC_MODES */
95
96 static inline int bits2pulses(const CELTMode *m, const celt_int16_t *cache, int bits)
97 {
98    int i;
99    int lo, hi;
100    lo = 0;
101    hi = MAX_PULSES-1;
102    
103    /* Instead of using the "bisection condition" we use a fixed number of 
104       iterations because it should be faster */
105    /*while (hi-lo != 1)*/
106    for (i=0;i<LOG_MAX_PULSES;i++)
107    {
108       int mid = (lo+hi)>>1;
109       /* OPT: Make sure this is implemented with a conditional move */
110       if (cache[mid] >= bits)
111          hi = mid;
112       else
113          lo = mid;
114    }
115    if (bits-cache[lo] <= cache[hi]-bits)
116       return lo;
117    else
118       return hi;
119 }
120
121 static int vec_bits2pulses(const CELTMode *m, const celt_int16_t * const *cache, int *bits, int *pulses, int len)
122 {
123    int i;
124    int sum=0;
125
126    for (i=0;i<len;i++)
127    {
128       pulses[i] = bits2pulses(m, cache[i], bits[i]);
129       sum += cache[i][pulses[i]];
130    }
131    /*printf ("sum = %d\n", sum);*/
132    return sum;
133 }
134
135 static int interp_bits2pulses(const CELTMode *m, const celt_int16_t * const *cache, int *bits1, int *bits2, int *ebits1, int *ebits2, int total, int *pulses, int *ebits, int len)
136 {
137    int esum;
138    int lo, hi, out;
139    int j;
140    VARDECL(int, bits);
141    const int C = CHANNELS(m);
142    SAVE_STACK;
143    ALLOC(bits, len, int);
144    lo = 0;
145    hi = 1<<BITRES;
146    while (hi-lo != 1)
147    {
148       int mid = (lo+hi)>>1;
149       esum = 0;
150       for (j=0;j<len;j++)
151       {
152          ebits[j] = (((1<<BITRES)-mid)*ebits1[j] + mid*ebits2[j] + (1<<(BITRES-1)))>>BITRES;
153          esum += ebits[j];
154       }
155       for (j=0;j<len;j++)
156          bits[j] = ((1<<BITRES)-mid)*bits1[j] + mid*bits2[j];
157       if (vec_bits2pulses(m, cache, bits, pulses, len) > (total-C*esum)<<BITRES)
158          hi = mid;
159       else
160          lo = mid;
161    }
162    esum = 0;
163    /*printf ("interp bisection gave %d\n", lo);*/
164    for (j=0;j<len;j++)
165    {
166       ebits[j] = (((1<<BITRES)-lo)*ebits1[j] + lo*ebits2[j] + (1<<(BITRES-1)))>>BITRES;
167       esum += ebits[j];
168    }
169    for (j=0;j<len;j++)
170       bits[j] = ((1<<BITRES)-lo)*bits1[j] + lo*bits2[j];
171    out = vec_bits2pulses(m, cache, bits, pulses, len);
172    /*printf ("left to allocate: %d\n", total-C*esum-(out>>BITRES));*/
173    /* Do some refinement to use up all bits. In the first pass, we can only add pulses to 
174       bands that are under their allocated budget. In the second pass, anything goes */
175    for (j=0;j<len;j++)
176    {
177       if (cache[j][pulses[j]] < bits[j] && pulses[j]<MAX_PULSES-1)
178       {
179          if (out+cache[j][pulses[j]+1]-cache[j][pulses[j]] <= (total-C*esum)<<BITRES)
180          {
181             out = out+cache[j][pulses[j]+1]-cache[j][pulses[j]];
182             pulses[j] += 1;
183          }
184       }
185    }
186    while(1)
187    {
188       int incremented = 0;
189       for (j=0;j<len;j++)
190       {
191          if (pulses[j]<MAX_PULSES-1)
192          {
193             if (out+cache[j][pulses[j]+1]-cache[j][pulses[j]] <= (total-C*esum)<<BITRES)
194             {
195                out = out+cache[j][pulses[j]+1]-cache[j][pulses[j]];
196                pulses[j] += 1;
197                incremented = 1;
198             }
199          }
200       }
201       if (!incremented)
202             break;
203    }
204    RESTORE_STACK;
205    return (out+BITROUND) >> BITRES;
206 }
207
208 int compute_allocation(const CELTMode *m, int *offsets, const int *stereo_mode, int total, int *pulses, int *ebits)
209 {
210    int lo, hi, len, ret, i;
211    VARDECL(int, bits1);
212    VARDECL(int, bits2);
213    VARDECL(int, ebits1);
214    VARDECL(int, ebits2);
215    VARDECL(const celt_int16_t*, cache);
216    const int C = CHANNELS(m);
217    SAVE_STACK;
218    
219    len = m->nbEBands;
220    ALLOC(bits1, len, int);
221    ALLOC(bits2, len, int);
222    ALLOC(ebits1, len, int);
223    ALLOC(ebits2, len, int);
224    ALLOC(cache, len, const celt_int16_t*);
225    
226    if (m->nbChannels==2)
227    {
228       for (i=0;i<len;i++)
229       {
230          if (stereo_mode[i]==0)
231             cache[i] = m->bits_stereo[i];
232          else
233             cache[i] = m->bits[i];
234       }
235    } else {
236       for (i=0;i<len;i++)
237          cache[i] = m->bits[i];
238    }
239    
240    lo = 0;
241    hi = m->nbAllocVectors - 1;
242    while (hi-lo != 1)
243    {
244       int j;
245       int mid = (lo+hi) >> 1;
246       for (j=0;j<len;j++)
247       {
248          bits1[j] = (m->allocVectors[mid*len+j] + offsets[j])<<BITRES;
249          if (bits1[j] < 0)
250             bits1[j] = 0;
251          /*printf ("%d ", bits[j]);*/
252       }
253       /*printf ("\n");*/
254       if (vec_bits2pulses(m, cache, bits1, pulses, len) > (total-C*m->energy_alloc[mid*(len+1)+len])<<BITRES)
255          hi = mid;
256       else
257          lo = mid;
258       /*printf ("lo = %d, hi = %d\n", lo, hi);*/
259    }
260    /*printf ("interp between %d and %d\n", lo, hi);*/
261    {
262       int j;
263       for (j=0;j<len;j++)
264       {
265          ebits1[j] = m->energy_alloc[lo*(len+1)+j];
266          ebits2[j] = m->energy_alloc[hi*(len+1)+j];
267          bits1[j] = m->allocVectors[lo*len+j] + offsets[j];
268          bits2[j] = m->allocVectors[hi*len+j] + offsets[j];
269          if (bits1[j] < 0)
270             bits1[j] = 0;
271          if (bits2[j] < 0)
272             bits2[j] = 0;
273       }
274       ret = interp_bits2pulses(m, cache, bits1, bits2, ebits1, ebits2, total, pulses, ebits, len);
275       RESTORE_STACK;
276       return ret;
277    }
278 }
279