Fixes rare overflow in intra_decision()
[opus.git] / libcelt / quant_bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "quant_bands.h"
38 #include "laplace.h"
39 #include <math.h>
40 #include "os_support.h"
41 #include "arch.h"
42 #include "mathops.h"
43 #include "stack_alloc.h"
44
45 #ifdef FIXED_POINT
46 /* Mean energy in each band quantized in Q6 */
47 const signed char eMeans[25] = {
48      124,122,115,106,100,
49       95, 91, 90, 99, 96,
50       94, 93, 98, 91, 86,
51       90, 88, 88, 90, 85,
52       64, 64, 64, 64, 64};
53 #else
54 /* Mean energy in each band quantized in Q6 and converted back to float */
55 const celt_word16 eMeans[25] = {
56       7.750000f, 7.625000f, 7.187500f, 6.625000f, 6.250000f,
57       5.937500f, 5.687500f, 5.625000f, 6.187500f, 6.000000f,
58       5.875000f, 5.812500f, 6.125000f, 5.687500f, 5.375000f,
59       5.625000f, 5.500000f, 5.500000f, 5.625000f, 5.312500f,
60       4.000000f, 4.000000f, 4.000000f, 4.000000f, 4.000000f};
61 #endif
62 /* prediction coefficients: 0.9, 0.8, 0.65, 0.5 */
63 #ifdef FIXED_POINT
64 static const celt_word16 pred_coef[4] = {29440, 26112, 21248, 16384};
65 #else
66 static const celt_word16 pred_coef[4] = {29440/32768., 26112/32768., 21248/32768., 16384/32768.};
67 #endif
68
69 static int intra_decision(const celt_word16 *eBands, celt_word16 *oldEBands, int start, int end, int len, int C)
70 {
71    int c, i;
72    celt_word32 dist = 0;
73    for (c=0;c<C;c++)
74    {
75       for (i=start;i<end;i++)
76       {
77          celt_word16 d = SHR16(SUB16(eBands[i+c*len], oldEBands[i+c*len]),1);
78          dist = MAC16_16(dist, d,d);
79       }
80    }
81    return SHR32(dist,2*DB_SHIFT-2) > 2*C*(end-start);
82 }
83
84 #ifndef STATIC_MODES
85
86 celt_int16 *quant_prob_alloc(const CELTMode *m)
87 {
88    int i;
89    celt_int16 *prob;
90    prob = celt_alloc(4*m->nbEBands*sizeof(celt_int16));
91    if (prob==NULL)
92      return NULL;
93    for (i=0;i<m->nbEBands;i++)
94    {
95       prob[2*i] = 7000-i*200;
96       prob[2*i+1] = ec_laplace_get_start_freq(prob[2*i]);
97    }
98    for (i=0;i<m->nbEBands;i++)
99    {
100       prob[2*m->nbEBands+2*i] = 9000-i*220;
101       prob[2*m->nbEBands+2*i+1] = ec_laplace_get_start_freq(prob[2*m->nbEBands+2*i]);
102    }
103    return prob;
104 }
105
106 void quant_prob_free(const celt_int16 *freq)
107 {
108    celt_free((celt_int16*)freq);
109 }
110 #endif
111
112 static void quant_coarse_energy_impl(const CELTMode *m, int start, int end,
113       const celt_word16 *eBands, celt_word16 *oldEBands, int budget,
114       const celt_int16 *prob, celt_word16 *error, ec_enc *enc, int _C, int LM,
115       int intra, celt_word16 max_decay)
116 {
117    const int C = CHANNELS(_C);
118    int i, c;
119    celt_word32 prev[2] = {0,0};
120    celt_word16 coef;
121    celt_word16 beta;
122
123    coef = pred_coef[LM];
124
125    ec_enc_bit_prob(enc, intra, 8192);
126    if (intra)
127    {
128       coef = 0;
129       prob += 2*m->nbEBands;
130    }
131    /* No theoretical justification for this, it just works */
132    beta = MULT16_16_P15(coef,coef);
133    /* Encode at a fixed coarse resolution */
134    for (i=start;i<end;i++)
135    {
136       c=0;
137       do {
138          int bits_left;
139          int qi;
140          celt_word16 q;
141          celt_word16 x;
142          celt_word32 f;
143          x = eBands[i+c*m->nbEBands];
144 #ifdef FIXED_POINT
145          f = SHL32(EXTEND32(x),15) -MULT16_16(coef,oldEBands[i+c*m->nbEBands])-prev[c];
146          /* Rounding to nearest integer here is really important! */
147          qi = (f+QCONST32(.5,DB_SHIFT+15))>>(DB_SHIFT+15);
148 #else
149          f = x-coef*oldEBands[i+c*m->nbEBands]-prev[c];
150          /* Rounding to nearest integer here is really important! */
151          qi = (int)floor(.5f+f);
152 #endif
153          /* Prevent the energy from going down too quickly (e.g. for bands
154             that have just one bin) */
155          if (qi < 0 && x < oldEBands[i+c*m->nbEBands]-max_decay)
156          {
157             qi += (int)SHR16(oldEBands[i+c*m->nbEBands]-max_decay-x, DB_SHIFT);
158             if (qi > 0)
159                qi = 0;
160          }
161          /* If we don't have enough bits to encode all the energy, just assume something safe.
162             We allow slightly busting the budget here */
163          bits_left = budget-(int)ec_enc_tell(enc, 0)-2*C*(end-i);
164          if (bits_left < 24)
165          {
166             if (qi > 1)
167                qi = 1;
168             if (qi < -1)
169                qi = -1;
170             if (bits_left<8)
171                qi = 0;
172          }
173          ec_laplace_encode_start(enc, &qi, prob[2*i], prob[2*i+1]);
174          error[i+c*m->nbEBands] = PSHR32(f,15) - SHL16(qi,DB_SHIFT);
175          q = SHL16(qi,DB_SHIFT);
176          
177          oldEBands[i+c*m->nbEBands] = PSHR32(MULT16_16(coef,oldEBands[i+c*m->nbEBands]) + prev[c] + SHL32(EXTEND32(q),15), 15);
178          prev[c] = prev[c] + SHL32(EXTEND32(q),15) - MULT16_16(beta,q);
179       } while (++c < C);
180    }
181 }
182
183 void quant_coarse_energy(const CELTMode *m, int start, int end, int effEnd,
184       const celt_word16 *eBands, celt_word16 *oldEBands, int budget,
185       const celt_int16 *prob, celt_word16 *error, ec_enc *enc, int _C, int LM,
186       int nbAvailableBytes, int force_intra, int *delayedIntra, int two_pass)
187 {
188    const int C = CHANNELS(_C);
189    int intra;
190    celt_word16 max_decay;
191    VARDECL(celt_word16, oldEBands_intra);
192    VARDECL(celt_word16, error_intra);
193    ec_enc enc_start_state;
194    ec_byte_buffer buf_start_state;
195    SAVE_STACK;
196
197    intra = force_intra || (*delayedIntra && nbAvailableBytes > end);
198    if (/*shortBlocks || */intra_decision(eBands, oldEBands, start, effEnd, m->nbEBands, C))
199       *delayedIntra = 1;
200    else
201       *delayedIntra = 0;
202
203    /* Encode the global flags using a simple probability model
204       (first symbols in the stream) */
205
206 #ifdef FIXED_POINT
207       max_decay = MIN32(QCONST16(16,DB_SHIFT), SHL32(EXTEND32(nbAvailableBytes),DB_SHIFT-3));
208 #else
209    max_decay = MIN32(16.f, .125f*nbAvailableBytes);
210 #endif
211
212    enc_start_state = *enc;
213    buf_start_state = *(enc->buf);
214
215    ALLOC(oldEBands_intra, C*m->nbEBands, celt_word16);
216    ALLOC(error_intra, C*m->nbEBands, celt_word16);
217    CELT_COPY(oldEBands_intra, oldEBands, C*end);
218
219    if (two_pass || intra)
220    {
221       quant_coarse_energy_impl(m, start, end, eBands, oldEBands_intra, budget,
222             prob, error_intra, enc, C, LM, 1, max_decay);
223    }
224
225    if (!intra)
226    {
227       ec_enc enc_intra_state;
228       ec_byte_buffer buf_intra_state;
229       int tell_intra;
230       VARDECL(unsigned char, intra_bits);
231
232       tell_intra = ec_enc_tell(enc, 3);
233
234       enc_intra_state = *enc;
235       buf_intra_state = *(enc->buf);
236
237       ALLOC(intra_bits, buf_intra_state.ptr-buf_start_state.ptr, unsigned char);
238       /* Copy bits from intra bit-stream */
239       CELT_COPY(intra_bits, buf_start_state.ptr, buf_intra_state.ptr-buf_start_state.ptr);
240
241       *enc = enc_start_state;
242       *(enc->buf) = buf_start_state;
243
244       quant_coarse_energy_impl(m, start, end, eBands, oldEBands, budget,
245             prob, error, enc, C, LM, 0, max_decay);
246
247       if (two_pass && ec_enc_tell(enc, 3) > tell_intra)
248       {
249          *enc = enc_intra_state;
250          *(enc->buf) = buf_intra_state;
251          /* Copy bits from to bit-stream */
252          CELT_COPY(buf_start_state.ptr, intra_bits, buf_intra_state.ptr-buf_start_state.ptr);
253          CELT_COPY(oldEBands, oldEBands_intra, C*end);
254          CELT_COPY(error, error_intra, C*end);
255       }
256    } else {
257       CELT_COPY(oldEBands, oldEBands_intra, C*end);
258       CELT_COPY(error, error_intra, C*end);
259    }
260    RESTORE_STACK;
261 }
262
263 void quant_fine_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, ec_enc *enc, int _C)
264 {
265    int i, c;
266    const int C = CHANNELS(_C);
267
268    /* Encode finer resolution */
269    for (i=start;i<end;i++)
270    {
271       celt_int16 frac = 1<<fine_quant[i];
272       if (fine_quant[i] <= 0)
273          continue;
274       c=0;
275       do {
276          int q2;
277          celt_word16 offset;
278 #ifdef FIXED_POINT
279          /* Has to be without rounding */
280          q2 = (error[i+c*m->nbEBands]+QCONST16(.5f,DB_SHIFT))>>(DB_SHIFT-fine_quant[i]);
281 #else
282          q2 = (int)floor((error[i+c*m->nbEBands]+.5f)*frac);
283 #endif
284          if (q2 > frac-1)
285             q2 = frac-1;
286          if (q2<0)
287             q2 = 0;
288          ec_enc_bits(enc, q2, fine_quant[i]);
289 #ifdef FIXED_POINT
290          offset = SUB16(SHR16(SHL16(q2,DB_SHIFT)+QCONST16(.5,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
291 #else
292          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
293 #endif
294          oldEBands[i+c*m->nbEBands] += offset;
295          error[i+c*m->nbEBands] -= offset;
296          /*printf ("%f ", error[i] - offset);*/
297       } while (++c < C);
298    }
299 }
300
301 void quant_energy_finalise(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, int *fine_priority, int bits_left, ec_enc *enc, int _C)
302 {
303    int i, prio, c;
304    const int C = CHANNELS(_C);
305
306    /* Use up the remaining bits */
307    for (prio=0;prio<2;prio++)
308    {
309       for (i=start;i<end && bits_left>=C ;i++)
310       {
311          if (fine_quant[i] >= 7 || fine_priority[i]!=prio)
312             continue;
313          c=0;
314          do {
315             int q2;
316             celt_word16 offset;
317             q2 = error[i+c*m->nbEBands]<0 ? 0 : 1;
318             ec_enc_bits(enc, q2, 1);
319 #ifdef FIXED_POINT
320             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5,DB_SHIFT),fine_quant[i]+1);
321 #else
322             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
323 #endif
324             oldEBands[i+c*m->nbEBands] += offset;
325             bits_left--;
326          } while (++c < C);
327       }
328    }
329 }
330
331 void unquant_coarse_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int intra, const celt_int16 *prob, ec_dec *dec, int _C, int LM)
332 {
333    int i, c;
334    celt_word32 prev[2] = {0, 0};
335    celt_word16 coef;
336    celt_word16 beta;
337    const int C = CHANNELS(_C);
338
339    coef = pred_coef[LM];
340
341    if (intra)
342    {
343       coef = 0;
344       prob += 2*m->nbEBands;
345    }
346    /* No theoretical justification for this, it just works */
347    beta = MULT16_16_P15(coef,coef);
348
349    /* Decode at a fixed coarse resolution */
350    for (i=start;i<end;i++)
351    {
352       c=0;
353       do {
354          int qi;
355          celt_word16 q;
356          qi = ec_laplace_decode_start(dec, prob[2*i], prob[2*i+1]);
357          q = SHL16(qi,DB_SHIFT);
358
359          oldEBands[i+c*m->nbEBands] = PSHR32(MULT16_16(coef,oldEBands[i+c*m->nbEBands]) + prev[c] + SHL32(EXTEND32(q),15), 15);
360          prev[c] = prev[c] + SHL32(EXTEND32(q),15) - MULT16_16(beta,q);
361       } while (++c < C);
362    }
363 }
364
365 void unquant_fine_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int *fine_quant, ec_dec *dec, int _C)
366 {
367    int i, c;
368    const int C = CHANNELS(_C);
369    /* Decode finer resolution */
370    for (i=start;i<end;i++)
371    {
372       if (fine_quant[i] <= 0)
373          continue;
374       c=0; 
375       do {
376          int q2;
377          celt_word16 offset;
378          q2 = ec_dec_bits(dec, fine_quant[i]);
379 #ifdef FIXED_POINT
380          offset = SUB16(SHR16(SHL16(q2,DB_SHIFT)+QCONST16(.5,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
381 #else
382          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
383 #endif
384          oldEBands[i+c*m->nbEBands] += offset;
385       } while (++c < C);
386    }
387 }
388
389 void unquant_energy_finalise(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int *fine_quant,  int *fine_priority, int bits_left, ec_dec *dec, int _C)
390 {
391    int i, prio, c;
392    const int C = CHANNELS(_C);
393
394    /* Use up the remaining bits */
395    for (prio=0;prio<2;prio++)
396    {
397       for (i=start;i<end && bits_left>=C ;i++)
398       {
399          if (fine_quant[i] >= 7 || fine_priority[i]!=prio)
400             continue;
401          c=0;
402          do {
403             int q2;
404             celt_word16 offset;
405             q2 = ec_dec_bits(dec, 1);
406 #ifdef FIXED_POINT
407             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5,DB_SHIFT),fine_quant[i]+1);
408 #else
409             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
410 #endif
411             oldEBands[i+c*m->nbEBands] += offset;
412             bits_left--;
413          } while (++c < C);
414       }
415    }
416 }
417
418 void log2Amp(const CELTMode *m, int start, int end,
419       celt_ener *eBands, celt_word16 *oldEBands, int _C)
420 {
421    int c, i;
422    const int C = CHANNELS(_C);
423    c=0;
424    do {
425       for (i=start;i<m->nbEBands;i++)
426       {
427          celt_word16 lg = oldEBands[i+c*m->nbEBands]
428                         + SHL16((celt_word16)eMeans[i],6);
429          eBands[i+c*m->nbEBands] = PSHR32(celt_exp2(SHL16(lg,11-DB_SHIFT)),4);
430          if (oldEBands[i+c*m->nbEBands] < -QCONST16(14.f,DB_SHIFT))
431             oldEBands[i+c*m->nbEBands] = -QCONST16(14.f,DB_SHIFT);
432       }
433    } while (++c < C);
434 }
435
436 void amp2Log2(const CELTMode *m, int effEnd, int end,
437       celt_ener *bandE, celt_word16 *bandLogE, int _C)
438 {
439    int c, i;
440    const int C = CHANNELS(_C);
441    c=0;
442    do {
443       for (i=0;i<effEnd;i++)
444          bandLogE[i+c*m->nbEBands] =
445                celt_log2(MAX32(QCONST32(.001f,14),SHL32(bandE[i+c*m->nbEBands],2)))
446                - SHL16((celt_word16)eMeans[i],6);
447       for (i=effEnd;i<end;i++)
448          bandLogE[c*m->nbEBands+i] = -QCONST16(14.f,DB_SHIFT);
449    } while (++c < C);
450 }