Retrained coarse energy mean and beta coefficients
[opus.git] / libcelt / quant_bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "quant_bands.h"
38 #include "laplace.h"
39 #include <math.h>
40 #include "os_support.h"
41 #include "arch.h"
42 #include "mathops.h"
43 #include "stack_alloc.h"
44
45 #ifdef FIXED_POINT
46 /* Mean energy in each band quantized in Q6 */
47 const signed char eMeans[25] = {
48       92, 85, 76, 69, 65,
49       61, 56, 55, 63, 61,
50       59, 57, 65, 61, 57,
51       61, 59, 64, 66, 63,
52       54, 54, 54, 54, 54
53 };
54 #else
55 /* Mean energy in each band quantized in Q6 and converted back to float */
56 const celt_word16 eMeans[25] = {
57       5.750000f, 5.312500f, 4.750000f, 4.312500f, 4.062500f,
58       3.812500f, 3.500000f, 3.437500f, 3.937500f, 3.812500f,
59       3.687500f, 3.562500f, 4.062500f, 3.812500f, 3.562500f,
60       3.812500f, 3.687500f, 4.000000f, 4.125000f, 3.937500f,
61       3.375000f, 3.375000f, 3.375000f, 3.375000f, 3.375000f
62       };
63 #endif
64 /* prediction coefficients: 0.9, 0.8, 0.65, 0.5 */
65 #ifdef FIXED_POINT
66 static const celt_word16 pred_coef[4] = {29440, 26112, 21248, 16384};
67 static const celt_word16 beta_coef[4] = {30147, 22282, 12124, 6554};
68 #else
69 static const celt_word16 pred_coef[4] = {29440/32768., 26112/32768., 21248/32768., 16384/32768.};
70 static const celt_word16 beta_coef[4] = {30147/32768., 22282/32768., 12124/32768., 6554/32768.};
71 #endif
72
73 static int intra_decision(const celt_word16 *eBands, celt_word16 *oldEBands, int start, int end, int len, int C)
74 {
75    int c, i;
76    celt_word32 dist = 0;
77    for (c=0;c<C;c++)
78    {
79       for (i=start;i<end;i++)
80       {
81          celt_word16 d = SHR16(SUB16(eBands[i+c*len], oldEBands[i+c*len]),2);
82          dist = MAC16_16(dist, d,d);
83       }
84    }
85    return SHR32(dist,2*DB_SHIFT-4) > 2*C*(end-start);
86 }
87
88 #ifndef STATIC_MODES
89
90 celt_int16 *quant_prob_alloc(const CELTMode *m)
91 {
92    int i;
93    celt_int16 *prob;
94    prob = celt_alloc(4*m->nbEBands*sizeof(celt_int16));
95    if (prob==NULL)
96      return NULL;
97    for (i=0;i<m->nbEBands;i++)
98    {
99       prob[2*i] = 7000-i*200;
100       prob[2*i+1] = ec_laplace_get_start_freq(prob[2*i]);
101    }
102    for (i=0;i<m->nbEBands;i++)
103    {
104       prob[2*m->nbEBands+2*i] = 9000-i*220;
105       prob[2*m->nbEBands+2*i+1] = ec_laplace_get_start_freq(prob[2*m->nbEBands+2*i]);
106    }
107    return prob;
108 }
109
110 void quant_prob_free(const celt_int16 *freq)
111 {
112    celt_free((celt_int16*)freq);
113 }
114 #endif
115
116 static void quant_coarse_energy_impl(const CELTMode *m, int start, int end,
117       const celt_word16 *eBands, celt_word16 *oldEBands, int budget,
118       const celt_int16 *prob, celt_word16 *error, ec_enc *enc, int _C, int LM,
119       int intra, celt_word16 max_decay)
120 {
121    const int C = CHANNELS(_C);
122    int i, c;
123    celt_word32 prev[2] = {0,0};
124    celt_word16 coef;
125    celt_word16 beta;
126
127    ec_enc_bit_prob(enc, intra, 8192);
128    if (intra)
129    {
130       coef = 0;
131       prob += 2*m->nbEBands;
132       beta = QCONST16(.15f,15);
133    } else {
134       beta = beta_coef[LM];
135       coef = pred_coef[LM];
136    }
137
138    /* Encode at a fixed coarse resolution */
139    for (i=start;i<end;i++)
140    {
141       c=0;
142       do {
143          int bits_left;
144          int qi;
145          celt_word16 q;
146          celt_word16 x;
147          celt_word32 f;
148          x = eBands[i+c*m->nbEBands];
149 #ifdef FIXED_POINT
150          f = SHL32(EXTEND32(x),15) -MULT16_16(coef,oldEBands[i+c*m->nbEBands])-prev[c];
151          /* Rounding to nearest integer here is really important! */
152          qi = (f+QCONST32(.5,DB_SHIFT+15))>>(DB_SHIFT+15);
153 #else
154          f = x-coef*oldEBands[i+c*m->nbEBands]-prev[c];
155          /* Rounding to nearest integer here is really important! */
156          qi = (int)floor(.5f+f);
157 #endif
158          /* Prevent the energy from going down too quickly (e.g. for bands
159             that have just one bin) */
160          if (qi < 0 && x < oldEBands[i+c*m->nbEBands]-max_decay)
161          {
162             qi += (int)SHR16(oldEBands[i+c*m->nbEBands]-max_decay-x, DB_SHIFT);
163             if (qi > 0)
164                qi = 0;
165          }
166          /* If we don't have enough bits to encode all the energy, just assume something safe.
167             We allow slightly busting the budget here */
168          bits_left = budget-(int)ec_enc_tell(enc, 0)-2*C*(end-i);
169          if (bits_left < 24)
170          {
171             if (qi > 1)
172                qi = 1;
173             if (qi < -1)
174                qi = -1;
175             if (bits_left<8)
176                qi = 0;
177          }
178          ec_laplace_encode_start(enc, &qi, prob[2*i], prob[2*i+1]);
179          error[i+c*m->nbEBands] = PSHR32(f,15) - SHL16(qi,DB_SHIFT);
180          q = SHL16(qi,DB_SHIFT);
181          
182          oldEBands[i+c*m->nbEBands] = PSHR32(MULT16_16(coef,oldEBands[i+c*m->nbEBands]) + prev[c] + SHL32(EXTEND32(q),15), 15);
183          prev[c] = prev[c] + SHL32(EXTEND32(q),15) - MULT16_16(beta,q);
184       } while (++c < C);
185    }
186 }
187
188 void quant_coarse_energy(const CELTMode *m, int start, int end, int effEnd,
189       const celt_word16 *eBands, celt_word16 *oldEBands, int budget,
190       const celt_int16 *prob, celt_word16 *error, ec_enc *enc, int _C, int LM,
191       int nbAvailableBytes, int force_intra, int *delayedIntra, int two_pass)
192 {
193    const int C = CHANNELS(_C);
194    int intra;
195    celt_word16 max_decay;
196    VARDECL(celt_word16, oldEBands_intra);
197    VARDECL(celt_word16, error_intra);
198    ec_enc enc_start_state;
199    ec_byte_buffer buf_start_state;
200    SAVE_STACK;
201
202    intra = force_intra || (*delayedIntra && nbAvailableBytes > end);
203    if (/*shortBlocks || */intra_decision(eBands, oldEBands, start, effEnd, m->nbEBands, C))
204       *delayedIntra = 1;
205    else
206       *delayedIntra = 0;
207
208    /* Encode the global flags using a simple probability model
209       (first symbols in the stream) */
210
211 #ifdef FIXED_POINT
212       max_decay = MIN32(QCONST16(16,DB_SHIFT), SHL32(EXTEND32(nbAvailableBytes),DB_SHIFT-3));
213 #else
214    max_decay = MIN32(16.f, .125f*nbAvailableBytes);
215 #endif
216
217    enc_start_state = *enc;
218    buf_start_state = *(enc->buf);
219
220    ALLOC(oldEBands_intra, C*m->nbEBands, celt_word16);
221    ALLOC(error_intra, C*m->nbEBands, celt_word16);
222    CELT_COPY(oldEBands_intra, oldEBands, C*end);
223
224    if (two_pass || intra)
225    {
226       quant_coarse_energy_impl(m, start, end, eBands, oldEBands_intra, budget,
227             prob, error_intra, enc, C, LM, 1, max_decay);
228    }
229
230    if (!intra)
231    {
232       ec_enc enc_intra_state;
233       ec_byte_buffer buf_intra_state;
234       int tell_intra;
235       VARDECL(unsigned char, intra_bits);
236
237       tell_intra = ec_enc_tell(enc, 3);
238
239       enc_intra_state = *enc;
240       buf_intra_state = *(enc->buf);
241
242       ALLOC(intra_bits, buf_intra_state.ptr-buf_start_state.ptr, unsigned char);
243       /* Copy bits from intra bit-stream */
244       CELT_COPY(intra_bits, buf_start_state.ptr, buf_intra_state.ptr-buf_start_state.ptr);
245
246       *enc = enc_start_state;
247       *(enc->buf) = buf_start_state;
248
249       quant_coarse_energy_impl(m, start, end, eBands, oldEBands, budget,
250             prob, error, enc, C, LM, 0, max_decay);
251
252       if (two_pass && ec_enc_tell(enc, 3) > tell_intra)
253       {
254          *enc = enc_intra_state;
255          *(enc->buf) = buf_intra_state;
256          /* Copy bits from to bit-stream */
257          CELT_COPY(buf_start_state.ptr, intra_bits, buf_intra_state.ptr-buf_start_state.ptr);
258          CELT_COPY(oldEBands, oldEBands_intra, C*end);
259          CELT_COPY(error, error_intra, C*end);
260       }
261    } else {
262       CELT_COPY(oldEBands, oldEBands_intra, C*end);
263       CELT_COPY(error, error_intra, C*end);
264    }
265    RESTORE_STACK;
266 }
267
268 void quant_fine_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, ec_enc *enc, int _C)
269 {
270    int i, c;
271    const int C = CHANNELS(_C);
272
273    /* Encode finer resolution */
274    for (i=start;i<end;i++)
275    {
276       celt_int16 frac = 1<<fine_quant[i];
277       if (fine_quant[i] <= 0)
278          continue;
279       c=0;
280       do {
281          int q2;
282          celt_word16 offset;
283 #ifdef FIXED_POINT
284          /* Has to be without rounding */
285          q2 = (error[i+c*m->nbEBands]+QCONST16(.5f,DB_SHIFT))>>(DB_SHIFT-fine_quant[i]);
286 #else
287          q2 = (int)floor((error[i+c*m->nbEBands]+.5f)*frac);
288 #endif
289          if (q2 > frac-1)
290             q2 = frac-1;
291          if (q2<0)
292             q2 = 0;
293          ec_enc_bits(enc, q2, fine_quant[i]);
294 #ifdef FIXED_POINT
295          offset = SUB16(SHR32(SHL32(EXTEND32(q2),DB_SHIFT)+QCONST16(.5,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
296 #else
297          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
298 #endif
299          oldEBands[i+c*m->nbEBands] += offset;
300          error[i+c*m->nbEBands] -= offset;
301          /*printf ("%f ", error[i] - offset);*/
302       } while (++c < C);
303    }
304 }
305
306 void quant_energy_finalise(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, int *fine_priority, int bits_left, ec_enc *enc, int _C)
307 {
308    int i, prio, c;
309    const int C = CHANNELS(_C);
310
311    /* Use up the remaining bits */
312    for (prio=0;prio<2;prio++)
313    {
314       for (i=start;i<end && bits_left>=C ;i++)
315       {
316          if (fine_quant[i] >= 7 || fine_priority[i]!=prio)
317             continue;
318          c=0;
319          do {
320             int q2;
321             celt_word16 offset;
322             q2 = error[i+c*m->nbEBands]<0 ? 0 : 1;
323             ec_enc_bits(enc, q2, 1);
324 #ifdef FIXED_POINT
325             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5,DB_SHIFT),fine_quant[i]+1);
326 #else
327             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
328 #endif
329             oldEBands[i+c*m->nbEBands] += offset;
330             bits_left--;
331          } while (++c < C);
332       }
333    }
334 }
335
336 void unquant_coarse_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int intra, const celt_int16 *prob, ec_dec *dec, int _C, int LM)
337 {
338    int i, c;
339    celt_word32 prev[2] = {0, 0};
340    celt_word16 coef;
341    celt_word16 beta;
342    const int C = CHANNELS(_C);
343
344
345    if (intra)
346    {
347       coef = 0;
348       beta = QCONST16(.15f,15);
349       prob += 2*m->nbEBands;
350    } else {
351       beta = beta_coef[LM];
352       coef = pred_coef[LM];
353    }
354
355    /* Decode at a fixed coarse resolution */
356    for (i=start;i<end;i++)
357    {
358       c=0;
359       do {
360          int qi;
361          celt_word16 q;
362          qi = ec_laplace_decode_start(dec, prob[2*i], prob[2*i+1]);
363          q = SHL16(qi,DB_SHIFT);
364
365          oldEBands[i+c*m->nbEBands] = PSHR32(MULT16_16(coef,oldEBands[i+c*m->nbEBands]) + prev[c] + SHL32(EXTEND32(q),15), 15);
366          prev[c] = prev[c] + SHL32(EXTEND32(q),15) - MULT16_16(beta,q);
367       } while (++c < C);
368    }
369 }
370
371 void unquant_fine_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int *fine_quant, ec_dec *dec, int _C)
372 {
373    int i, c;
374    const int C = CHANNELS(_C);
375    /* Decode finer resolution */
376    for (i=start;i<end;i++)
377    {
378       if (fine_quant[i] <= 0)
379          continue;
380       c=0; 
381       do {
382          int q2;
383          celt_word16 offset;
384          q2 = ec_dec_bits(dec, fine_quant[i]);
385 #ifdef FIXED_POINT
386          offset = SUB16(SHR32(SHL32(EXTEND32(q2),DB_SHIFT)+QCONST16(.5,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
387 #else
388          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
389 #endif
390          oldEBands[i+c*m->nbEBands] += offset;
391       } while (++c < C);
392    }
393 }
394
395 void unquant_energy_finalise(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int *fine_quant,  int *fine_priority, int bits_left, ec_dec *dec, int _C)
396 {
397    int i, prio, c;
398    const int C = CHANNELS(_C);
399
400    /* Use up the remaining bits */
401    for (prio=0;prio<2;prio++)
402    {
403       for (i=start;i<end && bits_left>=C ;i++)
404       {
405          if (fine_quant[i] >= 7 || fine_priority[i]!=prio)
406             continue;
407          c=0;
408          do {
409             int q2;
410             celt_word16 offset;
411             q2 = ec_dec_bits(dec, 1);
412 #ifdef FIXED_POINT
413             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5,DB_SHIFT),fine_quant[i]+1);
414 #else
415             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
416 #endif
417             oldEBands[i+c*m->nbEBands] += offset;
418             bits_left--;
419          } while (++c < C);
420       }
421    }
422 }
423
424 void log2Amp(const CELTMode *m, int start, int end,
425       celt_ener *eBands, celt_word16 *oldEBands, int _C)
426 {
427    int c, i;
428    const int C = CHANNELS(_C);
429    c=0;
430    do {
431       for (i=start;i<m->nbEBands;i++)
432       {
433          celt_word16 lg = oldEBands[i+c*m->nbEBands]
434                         + SHL16((celt_word16)eMeans[i],6);
435          eBands[i+c*m->nbEBands] = PSHR32(celt_exp2(SHL16(lg,11-DB_SHIFT)),4);
436          if (oldEBands[i+c*m->nbEBands] < -QCONST16(14.f,DB_SHIFT))
437             oldEBands[i+c*m->nbEBands] = -QCONST16(14.f,DB_SHIFT);
438       }
439    } while (++c < C);
440 }
441
442 void amp2Log2(const CELTMode *m, int effEnd, int end,
443       celt_ener *bandE, celt_word16 *bandLogE, int _C)
444 {
445    int c, i;
446    const int C = CHANNELS(_C);
447    c=0;
448    do {
449       for (i=0;i<effEnd;i++)
450          bandLogE[i+c*m->nbEBands] =
451                celt_log2(MAX32(QCONST32(.001f,14),SHL32(bandE[i+c*m->nbEBands],2)))
452                - SHL16((celt_word16)eMeans[i],6);
453       for (i=effEnd;i<end;i++)
454          bandLogE[c*m->nbEBands+i] = -QCONST16(14.f,DB_SHIFT);
455    } while (++c < C);
456 }