separated the two passes from interp_bits2pulses()
[opus.git] / libcelt / rate.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include "modes.h"
38 #include "cwrs.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 #include "entcode.h"
43 #include "rate.h"
44
45 #define BITRES 4
46 #define BITROUND 8
47 #define BITOVERFLOW 10000
48
49 #ifndef STATIC_MODES
50 static int log2_frac(ec_uint32 val, int frac)
51 {
52    int i;
53    /* EC_ILOG() actually returns log2()+1, go figure */
54    int L = EC_ILOG(val)-1;
55    /*printf ("in: %d %d ", val, L);*/
56    if (L>14)
57       val >>= L-14;
58    else if (L<14)
59       val <<= 14-L;
60    L <<= frac;
61    /*printf ("%d\n", val);*/
62    for (i=0;i<frac;i++)
63    {
64       val = (val*val) >> 15;
65       /*printf ("%d\n", val);*/
66       if (val > 16384)
67          L |= (1<<(frac-i-1));
68       else   
69          val <<= 1;
70    }
71    return L;
72 }
73
74 static int log2_frac64(ec_uint64 val, int frac)
75 {
76    int i;
77    /* EC_ILOG64() actually returns log2()+1, go figure */
78    int L = EC_ILOG64(val)-1;
79    /*printf ("in: %d %d ", val, L);*/
80    if (L>14)
81       val >>= L-14;
82    else if (L<14)
83       val <<= 14-L;
84    L <<= frac;
85    /*printf ("%d\n", val);*/
86    for (i=0;i<frac;i++)
87    {
88       val = (val*val) >> 15;
89       /*printf ("%d\n", val);*/
90       if (val > 16384)
91          L |= (1<<(frac-i-1));
92       else   
93          val <<= 1;
94    }
95    return L;
96 }
97
98 void compute_alloc_cache(CELTMode *m)
99 {
100    int i, prevN, BC;
101    celt_int16_t **bits;
102    const celt_int16_t *eBands = m->eBands;
103
104    bits = celt_alloc(m->nbEBands*sizeof(celt_int16_t*));
105    
106    BC = m->nbChannels;
107    prevN = -1;
108    for (i=0;i<m->nbEBands;i++)
109    {
110       int N = BC*(eBands[i+1]-eBands[i]);
111       if (N == prevN && eBands[i] < m->pitchEnd)
112       {
113          bits[i] = bits[i-1];
114       } else {
115          int j;
116          VARDECL(celt_uint64_t, u);
117          SAVE_STACK;
118          ALLOC(u, N, celt_uint64_t);
119          /* FIXME: We could save memory here */
120          bits[i] = celt_alloc(MAX_PULSES*sizeof(celt_int16_t));
121          for (j=0;j<MAX_PULSES;j++)
122          {
123             int done = 0;
124             int pulses = j;
125             /* For bands where there's no pitch, id 1 corresponds to intra prediction 
126             with no pulse. id 2 means intra prediction with one pulse, and so on.*/
127             if (eBands[i] >= m->pitchEnd)
128                pulses -= 1;
129             if (pulses < 0)
130                bits[i][j] = 0;
131             else {
132                celt_uint64_t nc;
133                nc=pulses?ncwrs_unext64(N, u):ncwrs_u64(N, 0, u);
134                bits[i][j] = log2_frac64(nc,BITRES);
135                /* FIXME: Could there be a better test for the max number of pulses that fit in 64 bits? */
136                if (bits[i][j] > (60<<BITRES))
137                   done = 1;
138                /* Add the intra-frame prediction bits */
139                if (eBands[i] >= m->pitchEnd)
140                {
141                   int max_pos = 2*eBands[i]-eBands[i+1];
142                   if (max_pos > 32)
143                      max_pos = 32;
144                   bits[i][j] += (1<<BITRES) + log2_frac(max_pos,BITRES);
145                }
146             }
147             if (done)
148                break;
149          }
150          for (;j<MAX_PULSES;j++)
151             bits[i][j] = BITOVERFLOW;
152          prevN = N;
153          RESTORE_STACK;
154       }
155    }
156    m->bits = (const celt_int16_t * const *)bits;
157 }
158
159 #endif /* !STATIC_MODES */
160
161 static inline int bits2pulses(const CELTMode *m, int band, int bits)
162 {
163    int i;
164    int lo, hi;
165    lo = 0;
166    hi = MAX_PULSES-1;
167    
168    /* Instead of using the "bisection confition" we use a fixed number of 
169       iterations because it should be faster */
170    /*while (hi-lo != 1)*/
171    for (i=0;i<LOG_MAX_PULSES;i++)
172    {
173       int mid = (lo+hi)>>1;
174       /* OPT: Make sure this is implemented with a conditional move */
175       if (m->bits[band][mid] >= bits)
176          hi = mid;
177       else
178          lo = mid;
179    }
180    if (bits-m->bits[band][lo] <= m->bits[band][hi]-bits)
181       return lo;
182    else
183       return hi;
184 }
185
186 static int vec_bits2pulses(const CELTMode *m, int *bits, int *pulses, int len)
187 {
188    int i;
189    int sum=0;
190
191    for (i=0;i<len;i++)
192    {
193       pulses[i] = bits2pulses(m, i, bits[i]);
194       sum += m->bits[i][pulses[i]];
195    }
196    /*printf ("sum = %d\n", sum);*/
197    return sum;
198 }
199
200 static int interp_bits2pulses(const CELTMode *m, int *bits1, int *bits2, int total, int *pulses, int len)
201 {
202    int lo, hi, out;
203    int j;
204    VARDECL(int, bits);
205    SAVE_STACK;
206    ALLOC(bits, len, int);
207    lo = 0;
208    hi = 1<<BITRES;
209    while (hi-lo != 1)
210    {
211       int mid = (lo+hi)>>1;
212       for (j=0;j<len;j++)
213          bits[j] = ((1<<BITRES)-mid)*bits1[j] + mid*bits2[j];
214       if (vec_bits2pulses(m, bits, pulses, len) > total<<BITRES)
215          hi = mid;
216       else
217          lo = mid;
218    }
219    /*printf ("interp bisection gave %d\n", lo);*/
220    for (j=0;j<len;j++)
221       bits[j] = ((1<<BITRES)-lo)*bits1[j] + lo*bits2[j];
222    out = vec_bits2pulses(m, bits, pulses, len);
223    /* Do some refinement to use up all bits. In the first pass, we can only add pulses to 
224       bands that are under their allocated budget. In the second pass, anything goes */
225    for (j=0;j<len;j++)
226    {
227       if (m->bits[j][pulses[j]] < bits[j] && pulses[j]<MAX_PULSES-1)
228       {
229          if (out+m->bits[j][pulses[j]+1]-m->bits[j][pulses[j]] <= total<<BITRES)
230          {
231             out = out+m->bits[j][pulses[j]+1]-m->bits[j][pulses[j]];
232             pulses[j] += 1;
233          }
234       }
235    }
236    while(1)
237    {
238       int incremented = 0;
239       for (j=0;j<len;j++)
240       {
241          if (pulses[j]<MAX_PULSES-1)
242          {
243             if (out+m->bits[j][pulses[j]+1]-m->bits[j][pulses[j]] <= total<<BITRES)
244             {
245                out = out+m->bits[j][pulses[j]+1]-m->bits[j][pulses[j]];
246                pulses[j] += 1;
247                incremented = 1;
248             }
249          }
250       }
251       if (!incremented)
252             break;
253    }
254    RESTORE_STACK;
255    return (out+BITROUND) >> BITRES;
256 }
257
258 int compute_allocation(const CELTMode *m, int *offsets, int total, int *pulses)
259 {
260    int lo, hi, len, ret;
261    VARDECL(int, bits1);
262    VARDECL(int, bits2);
263    SAVE_STACK;
264    
265    len = m->nbEBands;
266    ALLOC(bits1, len, int);
267    ALLOC(bits2, len, int);
268    lo = 0;
269    hi = m->nbAllocVectors - 1;
270    while (hi-lo != 1)
271    {
272       int j;
273       int mid = (lo+hi) >> 1;
274       for (j=0;j<len;j++)
275       {
276          bits1[j] = (m->allocVectors[mid*len+j] + offsets[j])<<BITRES;
277          if (bits1[j] < 0)
278             bits1[j] = 0;
279          /*printf ("%d ", bits[j]);*/
280       }
281       /*printf ("\n");*/
282       if (vec_bits2pulses(m, bits1, pulses, len) > total<<BITRES)
283          hi = mid;
284       else
285          lo = mid;
286       /*printf ("lo = %d, hi = %d\n", lo, hi);*/
287    }
288    {
289       int j;
290       for (j=0;j<len;j++)
291       {
292          bits1[j] = m->allocVectors[lo*len+j] + offsets[j];
293          bits2[j] = m->allocVectors[hi*len+j] + offsets[j];
294          if (bits1[j] < 0)
295             bits1[j] = 0;
296          if (bits2[j] < 0)
297             bits2[j] = 0;
298       }
299       ret = interp_bits2pulses(m, bits1, bits2, total, pulses, len);
300       RESTORE_STACK;
301       return ret;
302    }
303 }
304