Re-enabled intra-frame prediction, which seems to have exposed a few issues
[opus.git] / libcelt / rate.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include "modes.h"
34 #include "cwrs.h"
35 #include "arch.h"
36 #include "os_support.h"
37
38 #include "entcode.h"
39 #include "rate.h"
40
41 #define BITRES 4
42 #define BITROUND 8
43 #define BITOVERFLOW 10000
44
45 #define MAX_PULSES 64
46
47 int log2_frac(ec_uint32 val, int frac)
48 {
49    int i;
50    /* EC_ILOG() actually returns log2()+1, go figure */
51    int L = EC_ILOG(val)-1;
52    //printf ("in: %d %d ", val, L);
53    if (L>14)
54       val >>= L-14;
55    else if (L<14)
56       val <<= 14-L;
57    L <<= frac;
58    //printf ("%d\n", val);
59    for (i=0;i<frac;i++)
60    {
61       val = (val*val) >> 15;
62       //printf ("%d\n", val);
63       if (val > 16384)
64          L |= (1<<(frac-i-1));
65       else   
66          val <<= 1;
67    }
68    return L;
69 }
70
71 int log2_frac64(ec_uint64 val, int frac)
72 {
73    int i;
74    /* EC_ILOG64() actually returns log2()+1, go figure */
75    int L = EC_ILOG64(val)-1;
76    //printf ("in: %d %d ", val, L);
77    if (L>14)
78       val >>= L-14;
79    else if (L<14)
80       val <<= 14-L;
81    L <<= frac;
82    //printf ("%d\n", val);
83    for (i=0;i<frac;i++)
84    {
85       val = (val*val) >> 15;
86       //printf ("%d\n", val);
87       if (val > 16384)
88          L |= (1<<(frac-i-1));
89       else   
90          val <<= 1;
91    }
92    return L;
93 }
94
95
96 void alloc_init(struct alloc_data *alloc, const CELTMode *m)
97 {
98    int i, prevN, BC;
99    const int *eBands = m->eBands;
100    
101    alloc->mode = m;
102    alloc->len = m->nbEBands;
103    alloc->bands = m->eBands;
104    alloc->bits = celt_alloc(m->nbEBands*sizeof(int*));
105    
106    BC = m->nbMdctBlocks*m->nbChannels;
107    prevN = -1;
108    for (i=0;i<alloc->len;i++)
109    {
110       int N = BC*(eBands[i+1]-eBands[i]);
111       if (N == prevN && eBands[i] < m->pitchEnd)
112       {
113          alloc->bits[i] = alloc->bits[i-1];
114       } else {
115          int j;
116          /* FIXME: We could save memory here */
117          alloc->bits[i] = celt_alloc(MAX_PULSES*sizeof(int));
118          for (j=0;j<MAX_PULSES;j++)
119          {
120             int done = 0;
121             alloc->bits[i][j] = log2_frac64(ncwrs64(N, j),BITRES);
122             /* FIXME: Could there be a better test for the max number of pulses that fit in 64 bits? */
123             if (alloc->bits[i][j] > (60<<BITRES))
124                done = 1;
125             /* Add the intra-frame prediction bits */
126             if (eBands[i] >= m->pitchEnd)
127                alloc->bits[i][j] += (1<<BITRES) + log2_frac64(2*eBands[i]-eBands[i+1],BITRES);
128             /* We could just update rev_bits here */
129             if (done)
130                break;
131          }
132          for (;j<MAX_PULSES;j++)
133             alloc->bits[i][j] = BITOVERFLOW;
134          prevN = N;
135       }
136    }
137 }
138
139 void alloc_clear(struct alloc_data *alloc)
140 {
141    int i;
142    int *prevPtr = NULL;
143    for (i=0;i<alloc->len;i++)
144    {
145       if (alloc->bits[i] != prevPtr)
146       {
147          prevPtr = alloc->bits[i];
148          celt_free(alloc->bits[i]);
149       }
150    }
151    celt_free(alloc->bits);
152 }
153
154 int bits2pulses(const struct alloc_data *alloc, int band, int bits)
155 {
156    int lo, hi;
157    lo = 0;
158    hi = MAX_PULSES-1;
159    
160    while (hi-lo != 1)
161    {
162       int mid = (lo+hi)>>1;
163       if (alloc->bits[band][mid] >= bits)
164          hi = mid;
165       else
166          lo = mid;
167    }
168    if (bits-alloc->bits[band][lo] <= alloc->bits[band][hi]-bits)
169       return lo;
170    else
171       return hi;
172 }
173
174 int vec_bits2pulses(const struct alloc_data *alloc, const int *bands, int *bits, int *pulses, int len)
175 {
176    int i, BC;
177    int sum=0;
178    BC = alloc->mode->nbMdctBlocks*alloc->mode->nbChannels;
179
180    for (i=0;i<len;i++)
181    {
182       int N = (bands[i+1]-bands[i])*BC;
183       pulses[i] = bits2pulses(alloc, i, bits[i]);
184       sum += alloc->bits[i][pulses[i]];
185    }
186    //printf ("sum = %d\n", sum);
187    return sum;
188 }
189
190 int interp_bits2pulses(const struct alloc_data *alloc, int *bits1, int *bits2, int total, int *pulses, int len)
191 {
192    int lo, hi, out;
193    int j;
194    int bits[len];
195    const int *bands = alloc->bands;
196    lo = 0;
197    hi = 1<<BITRES;
198    while (hi-lo != 1)
199    {
200       int mid = (lo+hi)>>1;
201       for (j=0;j<len;j++)
202          bits[j] = ((1<<BITRES)-mid)*bits1[j] + mid*bits2[j];
203       if (vec_bits2pulses(alloc, bands, bits, pulses, len) > total<<BITRES)
204          hi = mid;
205       else
206          lo = mid;
207    }
208    //printf ("interp bisection gave %d\n", lo);
209    for (j=0;j<len;j++)
210       bits[j] = ((1<<BITRES)-lo)*bits1[j] + lo*bits2[j];
211    out = vec_bits2pulses(alloc, bands, bits, pulses, len);
212    /* Do some refinement to use up all bits */
213    while(1)
214    {
215       int incremented = 0;
216       for (j=0;j<len;j++)
217       {
218          if (alloc->bits[j][pulses[j]] < bits[j] && pulses[j]<MAX_PULSES-1)
219          {
220             if (out+alloc->bits[j][pulses[j]+1]-alloc->bits[j][pulses[j]] <= total<<BITRES)
221             {
222                out = out+alloc->bits[j][pulses[j]+1]-alloc->bits[j][pulses[j]];
223                pulses[j] += 1;
224                incremented = 1;
225                //printf ("INCREMENT %d\n", j);
226             }
227          }
228       }
229       if (!incremented)
230          break;
231    }
232    return (out+BITROUND) >> BITRES;
233 }
234
235 int compute_allocation(const struct alloc_data *alloc, int *offsets, int total, int *pulses)
236 {
237    int lo, hi, len;
238    const CELTMode *m;
239
240    m = alloc->mode;
241    len = m->nbEBands;
242    lo = 0;
243    hi = m->nbAllocVectors - 1;
244    while (hi-lo != 1)
245    {
246       int j;
247       int bits[len];
248       int pulses[len];
249       int mid = (lo+hi) >> 1;
250       for (j=0;j<len;j++)
251       {
252          bits[j] = (m->allocVectors[mid*len+j] + offsets[j])<<BITRES;
253          if (bits[j] < 0)
254             bits[j] = 0;
255          //printf ("%d ", bits[j]);
256       }
257       //printf ("\n");
258       if (vec_bits2pulses(alloc, alloc->bands, bits, pulses, len) > total<<BITRES)
259          hi = mid;
260       else
261          lo = mid;
262       //printf ("lo = %d, hi = %d\n", lo, hi);
263    }
264    {
265       int bits1[len];
266       int bits2[len];
267       int j;
268       for (j=0;j<len;j++)
269       {
270          bits1[j] = m->allocVectors[lo*len+j] + offsets[j];
271          bits2[j] = m->allocVectors[hi*len+j] + offsets[j];
272          if (bits1[j] < 0)
273             bits1[j] = 0;
274          if (bits2[j] < 0)
275             bits2[j] = 0;
276       }
277       return interp_bits2pulses(alloc, bits1, bits2, total, pulses, len);
278    }
279 }
280
281 #if 0
282 int main()
283 {
284    int i;
285    printf ("log(128) = %d\n", EC_ILOG(128));
286    for(i=1;i<2000000000;i+=1738)
287    {
288       printf ("%d %d\n", i, log2_frac(i, 10));
289    }
290    return 0;
291 }
292 #endif
293 #if 0
294 int main()
295 {
296    int i;
297    int offsets[18] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
298    int bits[18] = {10, 9, 9, 8, 8, 8, 8, 8, 8, 8, 9, 10, 8, 9, 10, 11, 6, 7};
299    int bits1[18] = {8, 7, 7, 6, 6, 6, 5, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5};
300    int bits2[18] = {15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15};
301    int bank[20] = {0,  4,  8, 12, 16, 20, 24, 28, 32, 38, 44, 52, 62, 74, 90,112,142,182, 232,256};
302    int pulses[18];
303    struct alloc_data alloc;
304    
305    alloc_init(&alloc, celt_mode0);
306    int b;
307    //b = vec_bits2pulses(&alloc, bank, bits, pulses, 18);
308    //printf ("total: %d bits\n", b);
309    //for (i=0;i<18;i++)
310    //   printf ("%d ", pulses[i]);
311    //printf ("\n");
312    //b = interp_bits2pulses(&alloc, bits1, bits2, 162, pulses, 18);
313    b = compute_allocation(&alloc, offsets, 190, pulses);
314    printf ("total: %d bits\n", b);
315    for (i=0;i<18;i++)
316       printf ("%d ", pulses[i]);
317    printf ("\n");
318
319    alloc_clear(&alloc);
320    return 0;
321 }
322 #endif