Looks like the bit allocation code is mostly working. Just need to actually
[opus.git] / libcelt / rate.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include "modes.h"
34 #include "cwrs.h"
35 #include "arch.h"
36 #include "os_support.h"
37
38 #include "entcode.h"
39 #include "rate.h"
40
41 #define BITRES 4
42 #define BITROUND 8
43 #define BITOVERFLOW 10000
44
45 #define MAX_PULSES 64
46
47 int log2_frac(ec_uint32 val, int frac)
48 {
49    int i;
50    /* EC_ILOG() actually returns log2()+1, go figure */
51    int L = EC_ILOG(val)-1;
52    //printf ("in: %d %d ", val, L);
53    if (L>14)
54       val >>= L-14;
55    else if (L<14)
56       val <<= 14-L;
57    L <<= frac;
58    //printf ("%d\n", val);
59    for (i=0;i<frac;i++)
60    {
61       val = (val*val) >> 15;
62       //printf ("%d\n", val);
63       if (val > 16384)
64          L |= (1<<(frac-i-1));
65       else   
66          val <<= 1;
67    }
68    return L;
69 }
70
71 int log2_frac64(ec_uint64 val, int frac)
72 {
73    int i;
74    /* EC_ILOG64() actually returns log2()+1, go figure */
75    int L = EC_ILOG64(val)-1;
76    //printf ("in: %d %d ", val, L);
77    if (L>14)
78       val >>= L-14;
79    else if (L<14)
80       val <<= 14-L;
81    L <<= frac;
82    //printf ("%d\n", val);
83    for (i=0;i<frac;i++)
84    {
85       val = (val*val) >> 15;
86       //printf ("%d\n", val);
87       if (val > 16384)
88          L |= (1<<(frac-i-1));
89       else   
90          val <<= 1;
91    }
92    return L;
93 }
94
95
96 void alloc_init(struct alloc_data *alloc, const CELTMode *m)
97 {
98    int i, prevN, BC;
99    const int *eBands = m->eBands;
100    
101    alloc->mode = m;
102    alloc->len = m->nbEBands;
103    alloc->bands = m->eBands;
104    alloc->bits = celt_alloc(m->nbEBands*sizeof(int*));
105    
106    BC = m->nbMdctBlocks*m->nbChannels;
107    prevN = -1;
108    for (i=0;i<alloc->len;i++)
109    {
110       int N = BC*(eBands[i+1]-eBands[i]);
111       if (N == prevN)
112       {
113          alloc->bits[i] = alloc->bits[i-1];
114       } else {
115          int j;
116          /* FIXME: We could save memory here */
117          alloc->bits[i] = celt_alloc(MAX_PULSES*sizeof(int));
118          for (j=0;j<MAX_PULSES;j++)
119          {
120             alloc->bits[i][j] = log2_frac64(ncwrs64(N, j),BITRES);
121             /* We could just update rev_bits here */
122             if (alloc->bits[i][j] > (60<<BITRES))
123                break;
124          }
125          for (;j<MAX_PULSES;j++)
126             alloc->bits[i][j] = BITOVERFLOW;
127          prevN = N;
128       }
129    }
130 }
131
132 void alloc_clear(struct alloc_data *alloc)
133 {
134    int i;
135    int *prevPtr = NULL;
136    for (i=0;i<alloc->len;i++)
137    {
138       if (alloc->bits[i] != prevPtr)
139       {
140          prevPtr = alloc->bits[i];
141          celt_free(alloc->bits[i]);
142       }
143    }
144    celt_free(alloc->bits);
145 }
146
147 int bits2pulses(const struct alloc_data *alloc, int band, int bits)
148 {
149    int lo, hi;
150    lo = 0;
151    hi = MAX_PULSES-1;
152    
153    while (hi-lo != 1)
154    {
155       int mid = (lo+hi)>>1;
156       if (alloc->bits[band][mid] >= bits)
157          hi = mid;
158       else
159          lo = mid;
160    }
161    if (bits-alloc->bits[band][lo] <= alloc->bits[band][hi]-bits)
162       return lo;
163    else
164       return hi;
165 }
166
167 int vec_bits2pulses(const struct alloc_data *alloc, const int *bands, int *bits, int *pulses, int len)
168 {
169    int i, BC;
170    int sum=0;
171    BC = alloc->mode->nbMdctBlocks*alloc->mode->nbChannels;
172
173    for (i=0;i<len;i++)
174    {
175       int N = (bands[i+1]-bands[i])*BC;
176       pulses[i] = bits2pulses(alloc, i, bits[i]);
177       sum += alloc->bits[i][pulses[i]];
178    }
179    //printf ("sum = %d\n", sum);
180    return sum;
181 }
182
183 int interp_bits2pulses(const struct alloc_data *alloc, int *bits1, int *bits2, int total, int *pulses, int len)
184 {
185    int lo, hi, out;
186    int j;
187    int bits[len];
188    int used_bits[len];
189    const int *bands = alloc->bands;
190    lo = 0;
191    hi = 1<<BITRES;
192    while (hi-lo != 1)
193    {
194       int mid = (lo+hi)>>1;
195       for (j=0;j<len;j++)
196          bits[j] = ((1<<BITRES)-mid)*bits1[j] + mid*bits2[j];
197       if (vec_bits2pulses(alloc, bands, bits, pulses, len) > total<<BITRES)
198          hi = mid;
199       else
200          lo = mid;
201    }
202    //printf ("interp bisection gave %d\n", lo);
203    for (j=0;j<len;j++)
204       bits[j] = ((1<<BITRES)-lo)*bits1[j] + lo*bits2[j];
205    out = vec_bits2pulses(alloc, bands, bits, pulses, len);
206    /* Do some refinement to use up all bits */
207    while(1)
208    {
209       int incremented = 0;
210       for (j=0;j<len;j++)
211       {
212          if (alloc->bits[j][pulses[j]] < bits[j] && pulses[j]<MAX_PULSES-1)
213          {
214             if (out+alloc->bits[j][pulses[j]+1]-alloc->bits[j][pulses[j]] <= total<<BITRES)
215             {
216                out = out+alloc->bits[j][pulses[j]+1]-alloc->bits[j][pulses[j]];
217                pulses[j] += 1;
218                incremented = 1;
219                //printf ("INCREMENT %d\n", j);
220             }
221          }
222       }
223       if (!incremented)
224          break;
225    }
226    return (out+BITROUND) >> BITRES;
227 }
228
229 int compute_allocation(const struct alloc_data *alloc, int *offsets, int total, int *pulses)
230 {
231    int lo, hi, len;
232    const CELTMode *m;
233
234    m = alloc->mode;
235    len = m->nbEBands;
236    lo = 0;
237    hi = m->nbAllocVectors - 1;
238    while (hi-lo != 1)
239    {
240       int j;
241       int bits[len];
242       int pulses[len];
243       int mid = (lo+hi) >> 1;
244       for (j=0;j<len;j++)
245       {
246          bits[j] = (m->allocVectors[mid*len+j] + offsets[j])<<BITRES;
247          if (bits[j] < 0)
248             bits[j] = 0;
249          //printf ("%d ", bits[j]);
250       }
251       //printf ("\n");
252       if (vec_bits2pulses(alloc, alloc->bands, bits, pulses, len) > total<<BITRES)
253          hi = mid;
254       else
255          lo = mid;
256       //printf ("lo = %d, hi = %d\n", lo, hi);
257    }
258    {
259       int bits1[len];
260       int bits2[len];
261       int j;
262       for (j=0;j<len;j++)
263       {
264          bits1[j] = m->allocVectors[lo*len+j] + offsets[j];
265          bits2[j] = m->allocVectors[hi*len+j] + offsets[j];
266          if (bits1[j] < 0)
267             bits1[j] = 0;
268          if (bits2[j] < 0)
269             bits2[j] = 0;
270       }
271       return interp_bits2pulses(alloc, bits1, bits2, total, pulses, len);
272    }
273 }
274
275 #if 0
276 int main()
277 {
278    int i;
279    printf ("log(128) = %d\n", EC_ILOG(128));
280    for(i=1;i<2000000000;i+=1738)
281    {
282       printf ("%d %d\n", i, log2_frac(i, 10));
283    }
284    return 0;
285 }
286 #endif
287 #if 0
288 int main()
289 {
290    int i;
291    int offsets[18] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
292    int bits[18] = {10, 9, 9, 8, 8, 8, 8, 8, 8, 8, 9, 10, 8, 9, 10, 11, 6, 7};
293    int bits1[18] = {8, 7, 7, 6, 6, 6, 5, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5};
294    int bits2[18] = {15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15};
295    int bank[20] = {0,  4,  8, 12, 16, 20, 24, 28, 32, 38, 44, 52, 62, 74, 90,112,142,182, 232,256};
296    int pulses[18];
297    struct alloc_data alloc;
298    
299    alloc_init(&alloc, celt_mode0);
300    int b;
301    //b = vec_bits2pulses(&alloc, bank, bits, pulses, 18);
302    //printf ("total: %d bits\n", b);
303    //for (i=0;i<18;i++)
304    //   printf ("%d ", pulses[i]);
305    //printf ("\n");
306    //b = interp_bits2pulses(&alloc, bits1, bits2, 162, pulses, 18);
307    b = compute_allocation(&alloc, offsets, 190, pulses);
308    printf ("total: %d bits\n", b);
309    for (i=0;i<18;i++)
310       printf ("%d ", pulses[i]);
311    printf ("\n");
312
313    alloc_clear(&alloc);
314    return 0;
315 }
316 #endif