Refactor the entropy coder.
[opus.git] / libcelt / quant_bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "quant_bands.h"
38 #include "laplace.h"
39 #include <math.h>
40 #include "os_support.h"
41 #include "arch.h"
42 #include "mathops.h"
43 #include "stack_alloc.h"
44 #include "rate.h"
45
46 #ifdef FIXED_POINT
47 /* Mean energy in each band quantized in Q6 */
48 static const signed char eMeans[25] = {
49       103,100, 92, 85, 81,
50        77, 72, 70, 78, 75,
51        73, 71, 78, 74, 69,
52        72, 70, 74, 76, 71,
53        60, 60, 60, 60, 60
54 };
55 #else
56 /* Mean energy in each band quantized in Q6 and converted back to float */
57 static const celt_word16 eMeans[25] = {
58       6.437500f, 6.250000f, 5.750000f, 5.312500f, 5.062500f,
59       4.812500f, 4.500000f, 4.375000f, 4.875000f, 4.687500f,
60       4.562500f, 4.437500f, 4.875000f, 4.625000f, 4.312500f,
61       4.500000f, 4.375000f, 4.625000f, 4.750000f, 4.437500f,
62       3.750000f, 3.750000f, 3.750000f, 3.750000f, 3.750000f
63 };
64 #endif
65 /* prediction coefficients: 0.9, 0.8, 0.65, 0.5 */
66 #ifdef FIXED_POINT
67 static const celt_word16 pred_coef[4] = {29440, 26112, 21248, 16384};
68 static const celt_word16 beta_coef[4] = {30147, 22282, 12124, 6554};
69 static const celt_word16 beta_intra = 4915;
70 #else
71 static const celt_word16 pred_coef[4] = {29440/32768., 26112/32768., 21248/32768., 16384/32768.};
72 static const celt_word16 beta_coef[4] = {30147/32768., 22282/32768., 12124/32768., 6554/32768.};
73 static const celt_word16 beta_intra = 4915/32768.;
74 #endif
75
76 /*Parameters of the Laplace-like probability models used for the coarse energy.
77   There is one pair of parameters for each frame size, prediction type
78    (inter/intra), and band number.
79   The first number of each pair is the probability of 0, and the second is the
80    decay rate, both in Q8 precision.*/
81 static const unsigned char e_prob_model[4][2][42] = {
82    /*120 sample frames.*/
83    {
84       /*Inter*/
85       {
86           72, 127,  65, 129,  66, 128,  65, 128,  64, 128,  62, 128,  64, 128,
87           64, 128,  92,  78,  92,  79,  92,  78,  90,  79, 116,  41, 115,  40,
88          114,  40, 132,  26, 132,  26, 145,  17, 161,  12, 176,  10, 177,  11
89       },
90       /*Intra*/
91       {
92           24, 179,  48, 138,  54, 135,  54, 132,  53, 134,  56, 133,  55, 132,
93           55, 132,  61, 114,  70,  96,  74,  88,  75,  88,  87,  74,  89,  66,
94           91,  67, 100,  59, 108,  50, 120,  40, 122,  37,  97,  43,  78,  50
95       }
96    },
97    /*240 sample frames.*/
98    {
99       /*Inter*/
100       {
101           83,  78,  84,  81,  88,  75,  86,  74,  87,  71,  90,  73,  93,  74,
102           93,  74, 109,  40, 114,  36, 117,  34, 117,  34, 143,  17, 145,  18,
103          146,  19, 162,  12, 165,  10, 178,   7, 189,   6, 190,   8, 177,   9
104       },
105       /*Intra*/
106       {
107           23, 178,  54, 115,  63, 102,  66,  98,  69,  99,  74,  89,  71,  91,
108           73,  91,  78,  89,  86,  80,  92,  66,  93,  64, 102,  59, 103,  60,
109          104,  60, 117,  52, 123,  44, 138,  35, 133,  31,  97,  38,  77,  45
110       }
111    },
112    /*480 sample frames.*/
113    {
114       /*Inter*/
115       {
116           61,  90,  93,  60, 105,  42, 107,  41, 110,  45, 116,  38, 113,  38,
117          112,  38, 124,  26, 132,  27, 136,  19, 140,  20, 155,  14, 159,  16,
118          158,  18, 170,  13, 177,  10, 187,   8, 192,   6, 175,   9, 159,  10
119       },
120       /*Intra*/
121       {
122           21, 178,  59, 110,  71,  86,  75,  85,  84,  83,  91,  66,  88,  73,
123           87,  72,  92,  75,  98,  72, 105,  58, 107,  54, 115,  52, 114,  55,
124          112,  56, 129,  51, 132,  40, 150,  33, 140,  29,  98,  35,  77,  42
125       }
126    },
127    /*960 sample frames.*/
128    {
129       /*Inter*/
130       {
131           42, 121,  96,  66, 108,  43, 111,  40, 117,  44, 123,  32, 120,  36,
132          119,  33, 127,  33, 134,  34, 139,  21, 147,  23, 152,  20, 158,  25,
133          154,  26, 166,  21, 173,  16, 184,  13, 184,  10, 150,  13, 139,  15
134       },
135       /*Intra*/
136       {
137           22, 178,  63, 114,  74,  82,  84,  83,  92,  82, 103,  62,  96,  72,
138           96,  67, 101,  73, 107,  72, 113,  55, 118,  52, 125,  52, 118,  52,
139          117,  55, 135,  49, 137,  39, 157,  32, 145,  29,  97,  33,  77,  40
140       }
141    }
142 };
143
144 static const unsigned char small_energy_icdf[3]={2,1,0};
145
146 static int intra_decision(const celt_word16 *eBands, celt_word16 *oldEBands, int start, int end, int len, int C)
147 {
148    int c, i;
149    celt_word32 dist = 0;
150    c=0; do {
151       for (i=start;i<end;i++)
152       {
153          celt_word16 d = SHR16(SUB16(eBands[i+c*len], oldEBands[i+c*len]),2);
154          dist = MAC16_16(dist, d,d);
155       }
156    } while (++c<C);
157    return SHR32(dist,2*DB_SHIFT-4) > 2*C*(end-start);
158 }
159
160 static int quant_coarse_energy_impl(const CELTMode *m, int start, int end,
161       const celt_word16 *eBands, celt_word16 *oldEBands,
162       ec_int32 budget, ec_int32 tell,
163       const unsigned char *prob_model, celt_word16 *error, ec_enc *enc,
164       int _C, int LM, int intra, celt_word16 max_decay)
165 {
166    const int C = CHANNELS(_C);
167    int i, c;
168    int badness = 0;
169    celt_word32 prev[2] = {0,0};
170    celt_word16 coef;
171    celt_word16 beta;
172
173    if (tell+3 <= budget)
174       ec_enc_bit_logp(enc, intra, 3);
175    if (intra)
176    {
177       coef = 0;
178       beta = beta_intra;
179    } else {
180       beta = beta_coef[LM];
181       coef = pred_coef[LM];
182    }
183
184    /* Encode at a fixed coarse resolution */
185    for (i=start;i<end;i++)
186    {
187       c=0;
188       do {
189          int bits_left;
190          int qi, qi0;
191          celt_word32 q;
192          celt_word16 x;
193          celt_word32 f, tmp;
194          celt_word16 oldE;
195          celt_word16 decay_bound;
196          x = eBands[i+c*m->nbEBands];
197          oldE = MAX16(-QCONST16(9.f,DB_SHIFT), oldEBands[i+c*m->nbEBands]);
198 #ifdef FIXED_POINT
199          f = SHL32(EXTEND32(x),7) - PSHR32(MULT16_16(coef,oldE), 8) - prev[c];
200          /* Rounding to nearest integer here is really important! */
201          qi = (f+QCONST32(.5f,DB_SHIFT+7))>>(DB_SHIFT+7);
202          decay_bound = EXTRACT16(MAX32(-QCONST16(28.f,DB_SHIFT),
203                SUB32((celt_word32)oldEBands[i+c*m->nbEBands],max_decay)));
204 #else
205          f = x-coef*oldE-prev[c];
206          /* Rounding to nearest integer here is really important! */
207          qi = (int)floor(.5f+f);
208          decay_bound = MAX16(-QCONST16(28.f,DB_SHIFT), oldEBands[i+c*m->nbEBands]) - max_decay;
209 #endif
210          /* Prevent the energy from going down too quickly (e.g. for bands
211             that have just one bin) */
212          if (qi < 0 && x < decay_bound)
213          {
214             qi += (int)SHR16(SUB16(decay_bound,x), DB_SHIFT);
215             if (qi > 0)
216                qi = 0;
217          }
218          qi0 = qi;
219          /* If we don't have enough bits to encode all the energy, just assume
220              something safe. */
221          tell = ec_tell(enc);
222          bits_left = budget-tell-3*C*(end-i);
223          if (i!=start && bits_left < 30)
224          {
225             if (bits_left < 24)
226                qi = IMIN(1, qi);
227             if (bits_left < 16)
228                qi = IMAX(-1, qi);
229          }
230          if (budget-tell >= 15)
231          {
232             int pi;
233             pi = 2*IMIN(i,20);
234             ec_laplace_encode(enc, &qi,
235                   prob_model[pi]<<7, prob_model[pi+1]<<6);
236          }
237          else if(budget-tell >= 2)
238          {
239             qi = IMAX(-1, IMIN(qi, 1));
240             ec_enc_icdf(enc, 2*qi^-(qi<0), small_energy_icdf, 2);
241          }
242          else if(budget-tell >= 1)
243          {
244             qi = IMIN(0, qi);
245             ec_enc_bit_logp(enc, -qi, 1);
246          }
247          else
248             qi = -1;
249          error[i+c*m->nbEBands] = PSHR32(f,7) - SHL16(qi,DB_SHIFT);
250          badness += abs(qi0-qi);
251          q = SHL32(EXTEND32(qi),DB_SHIFT);
252          
253          tmp = PSHR32(MULT16_16(coef,oldE),8) + prev[c] + SHL32(q,7);
254 #ifdef FIXED_POINT
255          tmp = MAX32(-QCONST32(28.f, DB_SHIFT+7), tmp);
256 #endif
257          oldEBands[i+c*m->nbEBands] = PSHR32(tmp, 7);
258          prev[c] = prev[c] + SHL32(q,7) - MULT16_16(beta,PSHR32(q,8));
259       } while (++c < C);
260    }
261    return badness;
262 }
263
264 void quant_coarse_energy(const CELTMode *m, int start, int end, int effEnd,
265       const celt_word16 *eBands, celt_word16 *oldEBands, ec_uint32 budget,
266       celt_word16 *error, ec_enc *enc, int _C, int LM, int nbAvailableBytes,
267       int force_intra, int *delayedIntra, int two_pass)
268 {
269    const int C = CHANNELS(_C);
270    int intra;
271    celt_word16 max_decay;
272    VARDECL(celt_word16, oldEBands_intra);
273    VARDECL(celt_word16, error_intra);
274    ec_enc enc_start_state;
275    ec_uint32 tell;
276    int badness1=0;
277    SAVE_STACK;
278
279    intra = force_intra || (*delayedIntra && nbAvailableBytes > end*C);
280    if (/*shortBlocks || */intra_decision(eBands, oldEBands, start, effEnd, m->nbEBands, C))
281       *delayedIntra = 1;
282    else
283       *delayedIntra = 0;
284
285    tell = ec_tell(enc);
286    if (tell+3 > budget)
287       two_pass = intra = 0;
288
289    /* Encode the global flags using a simple probability model
290       (first symbols in the stream) */
291
292 #ifdef FIXED_POINT
293       max_decay = MIN32(QCONST16(16.f,DB_SHIFT), SHL32(EXTEND32(nbAvailableBytes),DB_SHIFT-3));
294 #else
295    max_decay = MIN32(16.f, .125f*nbAvailableBytes);
296 #endif
297
298    enc_start_state = *enc;
299
300    ALLOC(oldEBands_intra, C*m->nbEBands, celt_word16);
301    ALLOC(error_intra, C*m->nbEBands, celt_word16);
302    CELT_COPY(oldEBands_intra, oldEBands, C*end);
303
304    if (two_pass || intra)
305    {
306       badness1 = quant_coarse_energy_impl(m, start, end, eBands, oldEBands_intra, budget,
307             tell, e_prob_model[LM][1], error_intra, enc, C, LM, 1, max_decay);
308    }
309
310    if (!intra)
311    {
312       ec_enc enc_intra_state;
313       int tell_intra;
314       ec_uint32 nstart_bytes;
315       ec_uint32 nintra_bytes;
316       int badness2;
317       VARDECL(unsigned char, intra_bits);
318
319       tell_intra = ec_tell_frac(enc);
320
321       enc_intra_state = *enc;
322
323       nstart_bytes = ec_range_bytes(&enc_start_state);
324       nintra_bytes = ec_range_bytes(&enc_intra_state);
325       ALLOC(intra_bits, nintra_bytes-nstart_bytes, unsigned char);
326       /* Copy bits from intra bit-stream */
327       CELT_COPY(intra_bits,
328             ec_get_buffer(&enc_intra_state) + nstart_bytes,
329             nintra_bytes - nstart_bytes);
330
331       *enc = enc_start_state;
332
333       badness2 = quant_coarse_energy_impl(m, start, end, eBands, oldEBands, budget,
334             tell, e_prob_model[LM][intra], error, enc, C, LM, 0, max_decay);
335
336       if (two_pass && (badness1 < badness2 || (badness1 == badness2 && ec_tell_frac(enc) > tell_intra)))
337       {
338          *enc = enc_intra_state;
339          /* Copy intra bits to bit-stream */
340          CELT_COPY(ec_get_buffer(&enc_intra_state) + nstart_bytes,
341                intra_bits, nintra_bytes - nstart_bytes);
342          CELT_COPY(oldEBands, oldEBands_intra, C*end);
343          CELT_COPY(error, error_intra, C*end);
344       }
345    } else {
346       CELT_COPY(oldEBands, oldEBands_intra, C*end);
347       CELT_COPY(error, error_intra, C*end);
348    }
349    RESTORE_STACK;
350 }
351
352 void quant_fine_energy(const CELTMode *m, int start, int end, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, ec_enc *enc, int _C)
353 {
354    int i, c;
355    const int C = CHANNELS(_C);
356
357    /* Encode finer resolution */
358    for (i=start;i<end;i++)
359    {
360       celt_int16 frac = 1<<fine_quant[i];
361       if (fine_quant[i] <= 0)
362          continue;
363       c=0;
364       do {
365          int q2;
366          celt_word16 offset;
367 #ifdef FIXED_POINT
368          /* Has to be without rounding */
369          q2 = (error[i+c*m->nbEBands]+QCONST16(.5f,DB_SHIFT))>>(DB_SHIFT-fine_quant[i]);
370 #else
371          q2 = (int)floor((error[i+c*m->nbEBands]+.5f)*frac);
372 #endif
373          if (q2 > frac-1)
374             q2 = frac-1;
375          if (q2<0)
376             q2 = 0;
377          ec_enc_bits(enc, q2, fine_quant[i]);
378 #ifdef FIXED_POINT
379          offset = SUB16(SHR32(SHL32(EXTEND32(q2),DB_SHIFT)+QCONST16(.5f,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
380 #else
381          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
382 #endif
383          oldEBands[i+c*m->nbEBands] += offset;
384          error[i+c*m->nbEBands] -= offset;
385          /*printf ("%f ", error[i] - offset);*/
386       } while (++c < C);
387    }
388 }
389
390 void quant_energy_finalise(const CELTMode *m, int start, int end, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, int *fine_priority, int bits_left, ec_enc *enc, int _C)
391 {
392    int i, prio, c;
393    const int C = CHANNELS(_C);
394
395    /* Use up the remaining bits */
396    for (prio=0;prio<2;prio++)
397    {
398       for (i=start;i<end && bits_left>=C ;i++)
399       {
400          if (fine_quant[i] >= MAX_FINE_BITS || fine_priority[i]!=prio)
401             continue;
402          c=0;
403          do {
404             int q2;
405             celt_word16 offset;
406             q2 = error[i+c*m->nbEBands]<0 ? 0 : 1;
407             ec_enc_bits(enc, q2, 1);
408 #ifdef FIXED_POINT
409             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5f,DB_SHIFT),fine_quant[i]+1);
410 #else
411             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
412 #endif
413             oldEBands[i+c*m->nbEBands] += offset;
414             bits_left--;
415          } while (++c < C);
416       }
417    }
418 }
419
420 void unquant_coarse_energy(const CELTMode *m, int start, int end, celt_word16 *oldEBands, int intra, ec_dec *dec, int _C, int LM)
421 {
422    const unsigned char *prob_model = e_prob_model[LM][intra];
423    int i, c;
424    celt_word32 prev[2] = {0, 0};
425    celt_word16 coef;
426    celt_word16 beta;
427    const int C = CHANNELS(_C);
428    ec_int32 budget;
429    ec_int32 tell;
430
431
432    if (intra)
433    {
434       coef = 0;
435       beta = beta_intra;
436    } else {
437       beta = beta_coef[LM];
438       coef = pred_coef[LM];
439    }
440
441    budget = dec->storage*8;
442
443    /* Decode at a fixed coarse resolution */
444    for (i=start;i<end;i++)
445    {
446       c=0;
447       do {
448          int qi;
449          celt_word32 q;
450          celt_word32 tmp;
451          tell = ec_tell(dec);
452          if(budget-tell>=15)
453          {
454             int pi;
455             pi = 2*IMIN(i,20);
456             qi = ec_laplace_decode(dec,
457                   prob_model[pi]<<7, prob_model[pi+1]<<6);
458          }
459          else if(budget-tell>=2)
460          {
461             qi = ec_dec_icdf(dec, small_energy_icdf, 2);
462             qi = (qi>>1)^-(qi&1);
463          }
464          else if(budget-tell>=1)
465          {
466             qi = -ec_dec_bit_logp(dec, 1);
467          }
468          else
469             qi = -1;
470          q = SHL32(EXTEND32(qi),DB_SHIFT);
471
472          oldEBands[i+c*m->nbEBands] = MAX16(-QCONST16(9.f,DB_SHIFT), oldEBands[i+c*m->nbEBands]);
473          tmp = PSHR32(MULT16_16(coef,oldEBands[i+c*m->nbEBands]),8) + prev[c] + SHL32(q,7);
474 #ifdef FIXED_POINT
475          tmp = MAX32(-QCONST32(28.f, DB_SHIFT+7), tmp);
476 #endif
477          oldEBands[i+c*m->nbEBands] = PSHR32(tmp, 7);
478          prev[c] = prev[c] + SHL32(q,7) - MULT16_16(beta,PSHR32(q,8));
479       } while (++c < C);
480    }
481 }
482
483 void unquant_fine_energy(const CELTMode *m, int start, int end, celt_word16 *oldEBands, int *fine_quant, ec_dec *dec, int _C)
484 {
485    int i, c;
486    const int C = CHANNELS(_C);
487    /* Decode finer resolution */
488    for (i=start;i<end;i++)
489    {
490       if (fine_quant[i] <= 0)
491          continue;
492       c=0; 
493       do {
494          int q2;
495          celt_word16 offset;
496          q2 = ec_dec_bits(dec, fine_quant[i]);
497 #ifdef FIXED_POINT
498          offset = SUB16(SHR32(SHL32(EXTEND32(q2),DB_SHIFT)+QCONST16(.5f,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
499 #else
500          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
501 #endif
502          oldEBands[i+c*m->nbEBands] += offset;
503       } while (++c < C);
504    }
505 }
506
507 void unquant_energy_finalise(const CELTMode *m, int start, int end, celt_word16 *oldEBands, int *fine_quant,  int *fine_priority, int bits_left, ec_dec *dec, int _C)
508 {
509    int i, prio, c;
510    const int C = CHANNELS(_C);
511
512    /* Use up the remaining bits */
513    for (prio=0;prio<2;prio++)
514    {
515       for (i=start;i<end && bits_left>=C ;i++)
516       {
517          if (fine_quant[i] >= MAX_FINE_BITS || fine_priority[i]!=prio)
518             continue;
519          c=0;
520          do {
521             int q2;
522             celt_word16 offset;
523             q2 = ec_dec_bits(dec, 1);
524 #ifdef FIXED_POINT
525             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5f,DB_SHIFT),fine_quant[i]+1);
526 #else
527             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
528 #endif
529             oldEBands[i+c*m->nbEBands] += offset;
530             bits_left--;
531          } while (++c < C);
532       }
533    }
534 }
535
536 void log2Amp(const CELTMode *m, int start, int end,
537       celt_ener *eBands, celt_word16 *oldEBands, int _C)
538 {
539    int c, i;
540    const int C = CHANNELS(_C);
541    c=0;
542    do {
543       for (i=0;i<start;i++)
544          eBands[i+c*m->nbEBands] = 0;
545       for (;i<end;i++)
546       {
547          celt_word16 lg = ADD16(oldEBands[i+c*m->nbEBands],
548                          SHL16((celt_word16)eMeans[i],6));
549          eBands[i+c*m->nbEBands] = PSHR32(celt_exp2(lg),4);
550       }
551       for (;i<m->nbEBands;i++)
552          eBands[i+c*m->nbEBands] = 0;
553    } while (++c < C);
554 }
555
556 void amp2Log2(const CELTMode *m, int effEnd, int end,
557       celt_ener *bandE, celt_word16 *bandLogE, int _C)
558 {
559    int c, i;
560    const int C = CHANNELS(_C);
561    c=0;
562    do {
563       for (i=0;i<effEnd;i++)
564          bandLogE[i+c*m->nbEBands] =
565                celt_log2(SHL32(bandE[i+c*m->nbEBands],2))
566                - SHL16((celt_word16)eMeans[i],6);
567       for (i=effEnd;i<end;i++)
568          bandLogE[c*m->nbEBands+i] = -QCONST16(14.f,DB_SHIFT);
569    } while (++c < C);
570 }