Add coarse energy entropy model tuning.
[opus.git] / libcelt / quant_bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "quant_bands.h"
38 #include "laplace.h"
39 #include <math.h>
40 #include "os_support.h"
41 #include "arch.h"
42 #include "mathops.h"
43 #include "stack_alloc.h"
44
45 #ifdef FIXED_POINT
46 /* Mean energy in each band quantized in Q6 */
47 static const signed char eMeans[25] = {
48       103,100, 92, 85, 81,
49        77, 72, 70, 78, 75,
50        73, 71, 78, 74, 69,
51        72, 70, 74, 76, 71,
52        60, 60, 60, 60, 60
53 };
54 #else
55 /* Mean energy in each band quantized in Q6 and converted back to float */
56 static const celt_word16 eMeans[25] = {
57       6.437500f, 6.250000f, 5.750000f, 5.312500f, 5.062500f,
58       4.812500f, 4.500000f, 4.375000f, 4.875000f, 4.687500f,
59       4.562500f, 4.437500f, 4.875000f, 4.625000f, 4.312500f,
60       4.500000f, 4.375000f, 4.625000f, 4.750000f, 4.437500f,
61       3.750000f, 3.750000f, 3.750000f, 3.750000f, 3.750000f
62 };
63 #endif
64 /* prediction coefficients: 0.9, 0.8, 0.65, 0.5 */
65 #ifdef FIXED_POINT
66 static const celt_word16 pred_coef[4] = {29440, 26112, 21248, 16384};
67 static const celt_word16 beta_coef[4] = {30147, 22282, 12124, 6554};
68 #else
69 static const celt_word16 pred_coef[4] = {29440/32768., 26112/32768., 21248/32768., 16384/32768.};
70 static const celt_word16 beta_coef[4] = {30147/32768., 22282/32768., 12124/32768., 6554/32768.};
71 #endif
72
73 /*Parameters of the Laplace-like probability models used for the coarse energy.
74   There is one pair of parameters for each frame size, prediction type
75    (inter/intra), and band number.
76   The first number of each pair is the probability of 0, and the second is the
77    decay rate, both in Q8 precision.*/
78 static const unsigned char e_prob_model[4][2][42] = {
79    /*120 sample frames.*/
80    {
81       /*Inter*/
82       {
83           72, 127,  65, 129,  66, 128,  65, 128,  64, 128,  62, 128,  64, 128,
84           64, 128,  92,  78,  92,  79,  92,  78,  90,  79, 116,  41, 115,  40,
85          114,  40, 132,  26, 132,  26, 145,  17, 161,  12, 176,  10, 177,  11
86       },
87       /*Intra*/
88       {
89           24, 179,  48, 138,  54, 135,  54, 132,  53, 134,  56, 133,  55, 132,
90           55, 132,  61, 114,  70,  96,  74,  88,  75,  88,  87,  74,  89,  66,
91           91,  67, 100,  59, 108,  50, 120,  40, 122,  37,  97,  43,  78,  50
92       }
93    },
94    /*240 sample frames.*/
95    {
96       /*Inter*/
97       {
98           83,  78,  84,  81,  88,  75,  86,  74,  87,  71,  90,  73,  93,  74,
99           93,  74, 109,  40, 114,  36, 117,  34, 117,  34, 143,  17, 145,  18,
100          146,  19, 162,  12, 165,  10, 178,   7, 189,   6, 190,   8, 177,   9
101       },
102       /*Intra*/
103       {
104           23, 178,  54, 115,  63, 102,  66,  98,  69,  99,  74,  89,  71,  91,
105           73,  91,  78,  89,  86,  80,  92,  66,  93,  64, 102,  59, 103,  60,
106          104,  60, 117,  52, 123,  44, 138,  35, 133,  31,  97,  38,  77,  45
107       }
108    },
109    /*480 sample frames.*/
110    {
111       /*Inter*/
112       {
113           61,  90,  93,  60, 105,  42, 107,  41, 110,  45, 116,  38, 113,  38,
114          112,  38, 124,  26, 132,  27, 136,  19, 140,  20, 155,  14, 159,  16,
115          158,  18, 170,  13, 177,  10, 187,   8, 192,   6, 175,   9, 159,  10
116       },
117       /*Intra*/
118       {
119           21, 178,  59, 110,  71,  86,  75,  85,  84,  83,  91,  66,  88,  73,
120           87,  72,  92,  75,  98,  72, 105,  58, 107,  54, 115,  52, 114,  55,
121          112,  56, 129,  51, 132,  40, 150,  33, 140,  29,  98,  35,  77,  42
122       }
123    },
124    /*960 sample frames.*/
125    {
126       /*Inter*/
127       {
128           42, 121,  96,  66, 108,  43, 111,  40, 117,  44, 123,  32, 120,  36,
129          119,  33, 127,  33, 134,  34, 139,  21, 147,  23, 152,  20, 158,  25,
130          154,  26, 166,  21, 173,  16, 184,  13, 184,  10, 150,  13, 139,  15
131       },
132       /*Intra*/
133       {
134           22, 178,  63, 114,  74,  82,  84,  83,  92,  82, 103,  62,  96,  72,
135           96,  67, 101,  73, 107,  72, 113,  55, 118,  52, 125,  52, 118,  52,
136          117,  55, 135,  49, 137,  39, 157,  32, 145,  29,  97,  33,  77,  40
137       }
138    }
139 };
140
141 static int intra_decision(const celt_word16 *eBands, celt_word16 *oldEBands, int start, int end, int len, int C)
142 {
143    int c, i;
144    celt_word32 dist = 0;
145    c=0; do {
146       for (i=start;i<end;i++)
147       {
148          celt_word16 d = SHR16(SUB16(eBands[i+c*len], oldEBands[i+c*len]),2);
149          dist = MAC16_16(dist, d,d);
150       }
151    } while (++c<C);
152    return SHR32(dist,2*DB_SHIFT-4) > 2*C*(end-start);
153 }
154
155 static void quant_coarse_energy_impl(const CELTMode *m, int start, int end,
156       const celt_word16 *eBands, celt_word16 *oldEBands, int budget,
157       const unsigned char *prob_model, celt_word16 *error, ec_enc *enc,
158       int _C, int LM, int intra, celt_word16 max_decay)
159 {
160    const int C = CHANNELS(_C);
161    int i, c;
162    celt_word32 prev[2] = {0,0};
163    celt_word16 coef;
164    celt_word16 beta;
165
166    ec_enc_bit_prob(enc, intra, 8192);
167    if (intra)
168    {
169       coef = 0;
170       beta = QCONST16(.15f,15);
171    } else {
172       beta = beta_coef[LM];
173       coef = pred_coef[LM];
174    }
175
176    /* Encode at a fixed coarse resolution */
177    for (i=start;i<end;i++)
178    {
179       c=0;
180       do {
181          int bits_left;
182          int qi;
183          int pi;
184          celt_word16 q;
185          celt_word16 x;
186          celt_word32 f;
187          x = eBands[i+c*m->nbEBands];
188 #ifdef FIXED_POINT
189          f = SHL32(EXTEND32(x),15) -MULT16_16(coef,oldEBands[i+c*m->nbEBands])-prev[c];
190          /* Rounding to nearest integer here is really important! */
191          qi = (f+QCONST32(.5,DB_SHIFT+15))>>(DB_SHIFT+15);
192 #else
193          f = x-coef*oldEBands[i+c*m->nbEBands]-prev[c];
194          /* Rounding to nearest integer here is really important! */
195          qi = (int)floor(.5f+f);
196 #endif
197          /* Prevent the energy from going down too quickly (e.g. for bands
198             that have just one bin) */
199          if (qi < 0 && x < oldEBands[i+c*m->nbEBands]-max_decay)
200          {
201             qi += (int)SHR16(oldEBands[i+c*m->nbEBands]-max_decay-x, DB_SHIFT);
202             if (qi > 0)
203                qi = 0;
204          }
205          /* If we don't have enough bits to encode all the energy, just assume something safe.
206             We allow slightly busting the budget here */
207          bits_left = budget-(int)ec_enc_tell(enc, 0)-2*C*(end-i);
208          if (bits_left < 24)
209          {
210             if (qi > 1)
211                qi = 1;
212             if (qi < -1)
213                qi = -1;
214             if (bits_left<8)
215                qi = 0;
216          }
217          pi = 2*IMIN(i,20);
218          ec_laplace_encode(enc, &qi,
219                prob_model[pi]<<7, prob_model[pi+1]<<6);
220          error[i+c*m->nbEBands] = PSHR32(f,15) - SHL16(qi,DB_SHIFT);
221          q = SHL16(qi,DB_SHIFT);
222          
223          oldEBands[i+c*m->nbEBands] = PSHR32(MULT16_16(coef,oldEBands[i+c*m->nbEBands]) + prev[c] + SHL32(EXTEND32(q),15), 15);
224          prev[c] = prev[c] + SHL32(EXTEND32(q),15) - MULT16_16(beta,q);
225       } while (++c < C);
226    }
227 }
228
229 void quant_coarse_energy(const CELTMode *m, int start, int end, int effEnd,
230       const celt_word16 *eBands, celt_word16 *oldEBands, int budget,
231       celt_word16 *error, ec_enc *enc, int _C, int LM, int nbAvailableBytes,
232       int force_intra, int *delayedIntra, int two_pass)
233 {
234    const int C = CHANNELS(_C);
235    int intra;
236    celt_word16 max_decay;
237    VARDECL(celt_word16, oldEBands_intra);
238    VARDECL(celt_word16, error_intra);
239    ec_enc enc_start_state;
240    ec_byte_buffer buf_start_state;
241    SAVE_STACK;
242
243    intra = force_intra || (*delayedIntra && nbAvailableBytes > end);
244    if (/*shortBlocks || */intra_decision(eBands, oldEBands, start, effEnd, m->nbEBands, C))
245       *delayedIntra = 1;
246    else
247       *delayedIntra = 0;
248
249    /* Encode the global flags using a simple probability model
250       (first symbols in the stream) */
251
252 #ifdef FIXED_POINT
253       max_decay = MIN32(QCONST16(16,DB_SHIFT), SHL32(EXTEND32(nbAvailableBytes),DB_SHIFT-3));
254 #else
255    max_decay = MIN32(16.f, .125f*nbAvailableBytes);
256 #endif
257
258    enc_start_state = *enc;
259    buf_start_state = *(enc->buf);
260
261    ALLOC(oldEBands_intra, C*m->nbEBands, celt_word16);
262    ALLOC(error_intra, C*m->nbEBands, celt_word16);
263    CELT_COPY(oldEBands_intra, oldEBands, C*end);
264
265    if (two_pass || intra)
266    {
267       quant_coarse_energy_impl(m, start, end, eBands, oldEBands_intra, budget,
268             e_prob_model[LM][1], error_intra, enc, C, LM, 1, max_decay);
269    }
270
271    if (!intra)
272    {
273       ec_enc enc_intra_state;
274       ec_byte_buffer buf_intra_state;
275       int tell_intra;
276       VARDECL(unsigned char, intra_bits);
277
278       tell_intra = ec_enc_tell(enc, 3);
279
280       enc_intra_state = *enc;
281       buf_intra_state = *(enc->buf);
282
283       ALLOC(intra_bits, buf_intra_state.ptr-buf_start_state.ptr, unsigned char);
284       /* Copy bits from intra bit-stream */
285       CELT_COPY(intra_bits, buf_start_state.ptr, buf_intra_state.ptr-buf_start_state.ptr);
286
287       *enc = enc_start_state;
288       *(enc->buf) = buf_start_state;
289
290       quant_coarse_energy_impl(m, start, end, eBands, oldEBands, budget,
291             e_prob_model[LM][intra], error, enc, C, LM, 0, max_decay);
292
293       if (two_pass && ec_enc_tell(enc, 3) > tell_intra)
294       {
295          *enc = enc_intra_state;
296          *(enc->buf) = buf_intra_state;
297          /* Copy bits from to bit-stream */
298          CELT_COPY(buf_start_state.ptr, intra_bits, buf_intra_state.ptr-buf_start_state.ptr);
299          CELT_COPY(oldEBands, oldEBands_intra, C*end);
300          CELT_COPY(error, error_intra, C*end);
301       }
302    } else {
303       CELT_COPY(oldEBands, oldEBands_intra, C*end);
304       CELT_COPY(error, error_intra, C*end);
305    }
306    RESTORE_STACK;
307 }
308
309 void quant_fine_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, ec_enc *enc, int _C)
310 {
311    int i, c;
312    const int C = CHANNELS(_C);
313
314    /* Encode finer resolution */
315    for (i=start;i<end;i++)
316    {
317       celt_int16 frac = 1<<fine_quant[i];
318       if (fine_quant[i] <= 0)
319          continue;
320       c=0;
321       do {
322          int q2;
323          celt_word16 offset;
324 #ifdef FIXED_POINT
325          /* Has to be without rounding */
326          q2 = (error[i+c*m->nbEBands]+QCONST16(.5f,DB_SHIFT))>>(DB_SHIFT-fine_quant[i]);
327 #else
328          q2 = (int)floor((error[i+c*m->nbEBands]+.5f)*frac);
329 #endif
330          if (q2 > frac-1)
331             q2 = frac-1;
332          if (q2<0)
333             q2 = 0;
334          ec_enc_bits(enc, q2, fine_quant[i]);
335 #ifdef FIXED_POINT
336          offset = SUB16(SHR32(SHL32(EXTEND32(q2),DB_SHIFT)+QCONST16(.5,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
337 #else
338          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
339 #endif
340          oldEBands[i+c*m->nbEBands] += offset;
341          error[i+c*m->nbEBands] -= offset;
342          /*printf ("%f ", error[i] - offset);*/
343       } while (++c < C);
344    }
345 }
346
347 void quant_energy_finalise(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, int *fine_priority, int bits_left, ec_enc *enc, int _C)
348 {
349    int i, prio, c;
350    const int C = CHANNELS(_C);
351
352    /* Use up the remaining bits */
353    for (prio=0;prio<2;prio++)
354    {
355       for (i=start;i<end && bits_left>=C ;i++)
356       {
357          if (fine_quant[i] >= 7 || fine_priority[i]!=prio)
358             continue;
359          c=0;
360          do {
361             int q2;
362             celt_word16 offset;
363             q2 = error[i+c*m->nbEBands]<0 ? 0 : 1;
364             ec_enc_bits(enc, q2, 1);
365 #ifdef FIXED_POINT
366             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5,DB_SHIFT),fine_quant[i]+1);
367 #else
368             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
369 #endif
370             oldEBands[i+c*m->nbEBands] += offset;
371             bits_left--;
372          } while (++c < C);
373       }
374    }
375 }
376
377 void unquant_coarse_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int intra, ec_dec *dec, int _C, int LM)
378 {
379    const unsigned char *prob_model = e_prob_model[LM][intra];
380    int i, c;
381    celt_word32 prev[2] = {0, 0};
382    celt_word16 coef;
383    celt_word16 beta;
384    const int C = CHANNELS(_C);
385
386
387    if (intra)
388    {
389       coef = 0;
390       beta = QCONST16(.15f,15);
391    } else {
392       beta = beta_coef[LM];
393       coef = pred_coef[LM];
394    }
395
396    /* Decode at a fixed coarse resolution */
397    for (i=start;i<end;i++)
398    {
399       c=0;
400       do {
401          int qi;
402          int pi;
403          celt_word16 q;
404          pi = 2*IMIN(i,20);
405          qi = ec_laplace_decode(dec,
406                prob_model[pi]<<7, prob_model[pi+1]<<6);
407          q = SHL16(qi,DB_SHIFT);
408
409          oldEBands[i+c*m->nbEBands] = PSHR32(MULT16_16(coef,oldEBands[i+c*m->nbEBands]) + prev[c] + SHL32(EXTEND32(q),15), 15);
410          prev[c] = prev[c] + SHL32(EXTEND32(q),15) - MULT16_16(beta,q);
411       } while (++c < C);
412    }
413 }
414
415 void unquant_fine_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int *fine_quant, ec_dec *dec, int _C)
416 {
417    int i, c;
418    const int C = CHANNELS(_C);
419    /* Decode finer resolution */
420    for (i=start;i<end;i++)
421    {
422       if (fine_quant[i] <= 0)
423          continue;
424       c=0; 
425       do {
426          int q2;
427          celt_word16 offset;
428          q2 = ec_dec_bits(dec, fine_quant[i]);
429 #ifdef FIXED_POINT
430          offset = SUB16(SHR32(SHL32(EXTEND32(q2),DB_SHIFT)+QCONST16(.5,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
431 #else
432          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
433 #endif
434          oldEBands[i+c*m->nbEBands] += offset;
435       } while (++c < C);
436    }
437 }
438
439 void unquant_energy_finalise(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int *fine_quant,  int *fine_priority, int bits_left, ec_dec *dec, int _C)
440 {
441    int i, prio, c;
442    const int C = CHANNELS(_C);
443
444    /* Use up the remaining bits */
445    for (prio=0;prio<2;prio++)
446    {
447       for (i=start;i<end && bits_left>=C ;i++)
448       {
449          if (fine_quant[i] >= 7 || fine_priority[i]!=prio)
450             continue;
451          c=0;
452          do {
453             int q2;
454             celt_word16 offset;
455             q2 = ec_dec_bits(dec, 1);
456 #ifdef FIXED_POINT
457             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5,DB_SHIFT),fine_quant[i]+1);
458 #else
459             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
460 #endif
461             oldEBands[i+c*m->nbEBands] += offset;
462             bits_left--;
463          } while (++c < C);
464       }
465    }
466 }
467
468 void log2Amp(const CELTMode *m, int start, int end,
469       celt_ener *eBands, celt_word16 *oldEBands, int _C)
470 {
471    int c, i;
472    const int C = CHANNELS(_C);
473    c=0;
474    do {
475       for (i=start;i<m->nbEBands;i++)
476       {
477          celt_word16 lg = oldEBands[i+c*m->nbEBands]
478                         + SHL16((celt_word16)eMeans[i],6);
479          eBands[i+c*m->nbEBands] = PSHR32(celt_exp2(SHL16(lg,11-DB_SHIFT)),4);
480          if (oldEBands[i+c*m->nbEBands] < -QCONST16(14.f,DB_SHIFT))
481             oldEBands[i+c*m->nbEBands] = -QCONST16(14.f,DB_SHIFT);
482       }
483    } while (++c < C);
484 }
485
486 void amp2Log2(const CELTMode *m, int effEnd, int end,
487       celt_ener *bandE, celt_word16 *bandLogE, int _C)
488 {
489    int c, i;
490    const int C = CHANNELS(_C);
491    c=0;
492    do {
493       for (i=0;i<effEnd;i++)
494          bandLogE[i+c*m->nbEBands] =
495                celt_log2(MAX32(QCONST32(.001f,14),SHL32(bandE[i+c*m->nbEBands],2)))
496                - SHL16((celt_word16)eMeans[i],6);
497       for (i=effEnd;i<end;i++)
498          bandLogE[c*m->nbEBands+i] = -QCONST16(14.f,DB_SHIFT);
499    } while (++c < C);
500 }