More stereo infrastructure
[opus.git] / libcelt / rate.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include "modes.h"
38 #include "cwrs.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 #include "entcode.h"
43 #include "rate.h"
44
45 #define BITRES 4
46 #define BITROUND 8
47 #define BITOVERFLOW 10000
48
49 #ifndef STATIC_MODES
50 #if 0
51 static int log2_frac(ec_uint32 val, int frac)
52 {
53    int i;
54    /* EC_ILOG() actually returns log2()+1, go figure */
55    int L = EC_ILOG(val)-1;
56    /*printf ("in: %d %d ", val, L);*/
57    if (L>14)
58       val >>= L-14;
59    else if (L<14)
60       val <<= 14-L;
61    L <<= frac;
62    /*printf ("%d\n", val);*/
63    for (i=0;i<frac;i++)
64    {
65       val = (val*val) >> 15;
66       /*printf ("%d\n", val);*/
67       if (val > 16384)
68          L |= (1<<(frac-i-1));
69       else   
70          val <<= 1;
71    }
72    return L;
73 }
74 #endif
75
76 static int log2_frac64(ec_uint64 val, int frac)
77 {
78    int i;
79    /* EC_ILOG64() actually returns log2()+1, go figure */
80    int L = EC_ILOG64(val)-1;
81    /*printf ("in: %d %d ", val, L);*/
82    if (L>14)
83       val >>= L-14;
84    else if (L<14)
85       val <<= 14-L;
86    L <<= frac;
87    /*printf ("%d\n", val);*/
88    for (i=0;i<frac;i++)
89    {
90       val = (val*val) >> 15;
91       /*printf ("%d\n", val);*/
92       if (val > 16384)
93          L |= (1<<(frac-i-1));
94       else   
95          val <<= 1;
96    }
97    return L;
98 }
99
100 celt_int16_t **compute_alloc_cache(CELTMode *m, int C)
101 {
102    int i, prevN;
103    celt_int16_t **bits;
104    const celt_int16_t *eBands = m->eBands;
105
106    bits = celt_alloc(m->nbEBands*sizeof(celt_int16_t*));
107    
108    C = m->nbChannels;
109    prevN = -1;
110    for (i=0;i<m->nbEBands;i++)
111    {
112       int N = C*(eBands[i+1]-eBands[i]);
113       if (N == prevN && eBands[i] < m->pitchEnd)
114       {
115          bits[i] = bits[i-1];
116       } else {
117          int j;
118          VARDECL(celt_uint64_t, u);
119          SAVE_STACK;
120          ALLOC(u, N, celt_uint64_t);
121          /* FIXME: We could save memory here */
122          bits[i] = celt_alloc(MAX_PULSES*sizeof(celt_int16_t));
123          for (j=0;j<MAX_PULSES;j++)
124          {
125             int done = 0;
126             int pulses = j;
127             /* For bands where there's no pitch, id 1 corresponds to intra prediction 
128             with no pulse. id 2 means intra prediction with one pulse, and so on.*/
129             if (eBands[i] >= m->pitchEnd)
130                pulses -= 1;
131             if (pulses < 0)
132                bits[i][j] = 0;
133             else {
134                celt_uint64_t nc;
135                nc=pulses?ncwrs_unext64(N, u):ncwrs_u64(N, 0, u);
136                bits[i][j] = log2_frac64(nc,BITRES);
137                /* FIXME: Could there be a better test for the max number of pulses that fit in 64 bits? */
138                if (bits[i][j] > (60<<BITRES))
139                   done = 1;
140                /* Add the intra-frame prediction sign bit */
141                if (eBands[i] >= m->pitchEnd)
142                   bits[i][j] += (1<<BITRES);
143             }
144             if (done)
145                break;
146          }
147          for (;j<MAX_PULSES;j++)
148             bits[i][j] = BITOVERFLOW;
149          prevN = N;
150          RESTORE_STACK;
151       }
152    }
153    return bits;
154 }
155
156 #endif /* !STATIC_MODES */
157
158 static inline int bits2pulses(const CELTMode *m, const celt_int16_t *cache, int bits)
159 {
160    int i;
161    int lo, hi;
162    lo = 0;
163    hi = MAX_PULSES-1;
164    
165    /* Instead of using the "bisection condition" we use a fixed number of 
166       iterations because it should be faster */
167    /*while (hi-lo != 1)*/
168    for (i=0;i<LOG_MAX_PULSES;i++)
169    {
170       int mid = (lo+hi)>>1;
171       /* OPT: Make sure this is implemented with a conditional move */
172       if (cache[mid] >= bits)
173          hi = mid;
174       else
175          lo = mid;
176    }
177    if (bits-cache[lo] <= cache[hi]-bits)
178       return lo;
179    else
180       return hi;
181 }
182
183 static int vec_bits2pulses(const CELTMode *m, const celt_int16_t * const *cache, int *bits, int *pulses, int len)
184 {
185    int i;
186    int sum=0;
187
188    for (i=0;i<len;i++)
189    {
190       pulses[i] = bits2pulses(m, cache[i], bits[i]);
191       sum += cache[i][pulses[i]];
192    }
193    /*printf ("sum = %d\n", sum);*/
194    return sum;
195 }
196
197 static int interp_bits2pulses(const CELTMode *m, const celt_int16_t * const *cache, int *bits1, int *bits2, int total, int *pulses, int len)
198 {
199    int lo, hi, out;
200    int j;
201    VARDECL(int, bits);
202    SAVE_STACK;
203    ALLOC(bits, len, int);
204    lo = 0;
205    hi = 1<<BITRES;
206    while (hi-lo != 1)
207    {
208       int mid = (lo+hi)>>1;
209       for (j=0;j<len;j++)
210          bits[j] = ((1<<BITRES)-mid)*bits1[j] + mid*bits2[j];
211       if (vec_bits2pulses(m, cache, bits, pulses, len) > total<<BITRES)
212          hi = mid;
213       else
214          lo = mid;
215    }
216    /*printf ("interp bisection gave %d\n", lo);*/
217    for (j=0;j<len;j++)
218       bits[j] = ((1<<BITRES)-lo)*bits1[j] + lo*bits2[j];
219    out = vec_bits2pulses(m, cache, bits, pulses, len);
220    /* Do some refinement to use up all bits. In the first pass, we can only add pulses to 
221       bands that are under their allocated budget. In the second pass, anything goes */
222    for (j=0;j<len;j++)
223    {
224       if (cache[j][pulses[j]] < bits[j] && pulses[j]<MAX_PULSES-1)
225       {
226          if (out+cache[j][pulses[j]+1]-cache[j][pulses[j]] <= total<<BITRES)
227          {
228             out = out+cache[j][pulses[j]+1]-cache[j][pulses[j]];
229             pulses[j] += 1;
230          }
231       }
232    }
233    while(1)
234    {
235       int incremented = 0;
236       for (j=0;j<len;j++)
237       {
238          if (pulses[j]<MAX_PULSES-1)
239          {
240             if (out+cache[j][pulses[j]+1]-cache[j][pulses[j]] <= total<<BITRES)
241             {
242                out = out+cache[j][pulses[j]+1]-cache[j][pulses[j]];
243                pulses[j] += 1;
244                incremented = 1;
245             }
246          }
247       }
248       if (!incremented)
249             break;
250    }
251    RESTORE_STACK;
252    return (out+BITROUND) >> BITRES;
253 }
254
255 int compute_allocation(const CELTMode *m, int *offsets, const int *stereo_mode, int total, int *pulses)
256 {
257    int lo, hi, len, ret, i;
258    VARDECL(int, bits1);
259    VARDECL(int, bits2);
260    VARDECL(const celt_int16_t*, cache);
261    SAVE_STACK;
262    
263    len = m->nbEBands;
264    ALLOC(bits1, len, int);
265    ALLOC(bits2, len, int);
266    ALLOC(cache, len, const celt_int16_t*);
267    
268    if (m->nbChannels==2)
269    {
270       for (i=0;i<len;i++)
271       {
272          if (stereo_mode[i]==0)
273             cache[i] = m->bits[i];
274       }
275    } else {
276       for (i=0;i<len;i++)
277          cache[i] = m->bits[i];
278    }
279    
280    lo = 0;
281    hi = m->nbAllocVectors - 1;
282    while (hi-lo != 1)
283    {
284       int j;
285       int mid = (lo+hi) >> 1;
286       for (j=0;j<len;j++)
287       {
288          bits1[j] = (m->allocVectors[mid*len+j] + offsets[j])<<BITRES;
289          if (bits1[j] < 0)
290             bits1[j] = 0;
291          /*printf ("%d ", bits[j]);*/
292       }
293       /*printf ("\n");*/
294       if (vec_bits2pulses(m, cache, bits1, pulses, len) > total<<BITRES)
295          hi = mid;
296       else
297          lo = mid;
298       /*printf ("lo = %d, hi = %d\n", lo, hi);*/
299    }
300    {
301       int j;
302       for (j=0;j<len;j++)
303       {
304          bits1[j] = m->allocVectors[lo*len+j] + offsets[j];
305          bits2[j] = m->allocVectors[hi*len+j] + offsets[j];
306          if (bits1[j] < 0)
307             bits1[j] = 0;
308          if (bits2[j] < 0)
309             bits2[j] = 0;
310       }
311       ret = interp_bits2pulses(m, cache, bits1, bits2, total, pulses, len);
312       RESTORE_STACK;
313       return ret;
314    }
315 }
316