Cleaning up intra_decision()
[opus.git] / libcelt / quant_bands.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "quant_bands.h"
38 #include "laplace.h"
39 #include <math.h>
40 #include "os_support.h"
41 #include "arch.h"
42 #include "mathops.h"
43 #include "stack_alloc.h"
44
45 #define E_MEANS_SIZE (3)
46
47 static const celt_word16 eMeans[E_MEANS_SIZE] = {QCONST16(7.5f,DB_SHIFT), -QCONST16(1.f,DB_SHIFT), -QCONST16(.5f,DB_SHIFT)};
48
49 /* prediction coefficients: 0.9, 0.8, 0.65, 0.5 */
50 #ifdef FIXED_POINT
51 static const celt_word16 pred_coef[4] = {29440, 26112, 21248, 16384};
52 #else
53 static const celt_word16 pred_coef[4] = {29440/32768., 26112/32768., 21248/32768., 16384/32768.};
54 #endif
55
56 int intra_decision(celt_word16 *eBands, celt_word16 *oldEBands, int start, int end, int len, int C)
57 {
58    int c, i;
59    celt_word32 dist = 0;
60    for (c=0;c<C;c++)
61    {
62       for (i=start;i<end;i++)
63       {
64          celt_word16 d = SUB16(eBands[i+c*len], oldEBands[i+c*len]);
65          dist = MAC16_16(dist, d,d);
66       }
67    }
68    return SHR32(dist,2*DB_SHIFT) > 2*C*(end-start);
69 }
70
71 int *quant_prob_alloc(const CELTMode *m)
72 {
73    int i;
74    int *prob;
75    prob = celt_alloc(4*m->nbEBands*sizeof(int));
76    if (prob==NULL)
77      return NULL;
78    for (i=0;i<m->nbEBands;i++)
79    {
80       prob[2*i] = 7000-i*200;
81       prob[2*i+1] = ec_laplace_get_start_freq(prob[2*i]);
82    }
83    for (i=0;i<m->nbEBands;i++)
84    {
85       prob[2*m->nbEBands+2*i] = 9000-i*220;
86       prob[2*m->nbEBands+2*i+1] = ec_laplace_get_start_freq(prob[2*m->nbEBands+2*i]);
87    }
88    return prob;
89 }
90
91 void quant_prob_free(int *freq)
92 {
93    celt_free(freq);
94 }
95
96 void quant_coarse_energy(const CELTMode *m, int start, int end, const celt_word16 *eBands, celt_word16 *oldEBands, int budget, int intra, int *prob, celt_word16 *error, ec_enc *enc, int _C, int LM, celt_word16 max_decay)
97 {
98    int i, c;
99    celt_word32 prev[2] = {0,0};
100    celt_word16 coef;
101    celt_word16 beta;
102    const int C = CHANNELS(_C);
103
104    coef = pred_coef[LM];
105
106    if (intra)
107    {
108       coef = 0;
109       prob += 2*m->nbEBands;
110    }
111    /* No theoretical justification for this, it just works */
112    beta = MULT16_16_P15(coef,coef);
113    /* Encode at a fixed coarse resolution */
114    for (i=start;i<end;i++)
115    {
116       c=0;
117       do {
118          int bits_left;
119          int qi;
120          celt_word16 q;
121          celt_word16 x;
122          celt_word32 f;
123          celt_word32 mean =  (i-start < E_MEANS_SIZE) ? SUB32(SHL32(EXTEND32(eMeans[i-start]),15), MULT16_16(coef,eMeans[i-start])) : 0;
124          x = eBands[i+c*m->nbEBands];
125 #ifdef FIXED_POINT
126          f = SHL32(EXTEND32(x),15)-mean -MULT16_16(coef,oldEBands[i+c*m->nbEBands])-prev[c];
127          /* Rounding to nearest integer here is really important! */
128          qi = (f+QCONST32(.5,DB_SHIFT+15))>>(DB_SHIFT+15);
129 #else
130          f = x-mean-coef*oldEBands[i+c*m->nbEBands]-prev[c];
131          /* Rounding to nearest integer here is really important! */
132          qi = (int)floor(.5f+f);
133 #endif
134          /* Prevent the energy from going down too quickly (e.g. for bands
135             that have just one bin) */
136          if (qi < 0 && x < oldEBands[i+c*m->nbEBands]-max_decay)
137          {
138             qi += SHR16(oldEBands[i+c*m->nbEBands]-max_decay-x, DB_SHIFT);
139             if (qi > 0)
140                qi = 0;
141          }
142          /* If we don't have enough bits to encode all the energy, just assume something safe.
143             We allow slightly busting the budget here */
144          bits_left = budget-(int)ec_enc_tell(enc, 0)-2*C*(end-i);
145          if (bits_left < 24)
146          {
147             if (qi > 1)
148                qi = 1;
149             if (qi < -1)
150                qi = -1;
151             if (bits_left<8)
152                qi = 0;
153          }
154          ec_laplace_encode_start(enc, &qi, prob[2*i], prob[2*i+1]);
155          error[i+c*m->nbEBands] = PSHR32(f,15) - SHL16(qi,DB_SHIFT);
156          q = SHL16(qi,DB_SHIFT);
157          
158          oldEBands[i+c*m->nbEBands] = PSHR32(MULT16_16(coef,oldEBands[i+c*m->nbEBands]) + mean + prev[c] + SHL32(EXTEND32(q),15), 15);
159          prev[c] = mean + prev[c] + SHL32(EXTEND32(q),15) - MULT16_16(beta,q);
160       } while (++c < C);
161    }
162 }
163
164 void quant_fine_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, ec_enc *enc, int _C)
165 {
166    int i, c;
167    const int C = CHANNELS(_C);
168
169    /* Encode finer resolution */
170    for (i=start;i<end;i++)
171    {
172       celt_int16 frac = 1<<fine_quant[i];
173       if (fine_quant[i] <= 0)
174          continue;
175       c=0;
176       do {
177          int q2;
178          celt_word16 offset;
179 #ifdef FIXED_POINT
180          /* Has to be without rounding */
181          q2 = (error[i+c*m->nbEBands]+QCONST16(.5f,DB_SHIFT))>>(DB_SHIFT-fine_quant[i]);
182 #else
183          q2 = (int)floor((error[i+c*m->nbEBands]+.5f)*frac);
184 #endif
185          if (q2 > frac-1)
186             q2 = frac-1;
187          if (q2<0)
188             q2 = 0;
189          ec_enc_bits(enc, q2, fine_quant[i]);
190 #ifdef FIXED_POINT
191          offset = SUB16(SHR16(SHL16(q2,DB_SHIFT)+QCONST16(.5,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
192 #else
193          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
194 #endif
195          oldEBands[i+c*m->nbEBands] += offset;
196          error[i+c*m->nbEBands] -= offset;
197          /*printf ("%f ", error[i] - offset);*/
198       } while (++c < C);
199    }
200 }
201
202 void quant_energy_finalise(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, celt_word16 *error, int *fine_quant, int *fine_priority, int bits_left, ec_enc *enc, int _C)
203 {
204    int i, prio, c;
205    const int C = CHANNELS(_C);
206
207    /* Use up the remaining bits */
208    for (prio=0;prio<2;prio++)
209    {
210       for (i=start;i<end && bits_left>=C ;i++)
211       {
212          if (fine_quant[i] >= 7 || fine_priority[i]!=prio)
213             continue;
214          c=0;
215          do {
216             int q2;
217             celt_word16 offset;
218             q2 = error[i+c*m->nbEBands]<0 ? 0 : 1;
219             ec_enc_bits(enc, q2, 1);
220 #ifdef FIXED_POINT
221             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5,DB_SHIFT),fine_quant[i]+1);
222 #else
223             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
224 #endif
225             oldEBands[i+c*m->nbEBands] += offset;
226             bits_left--;
227          } while (++c < C);
228       }
229    }
230    c=0;
231    do {
232       for (i=start;i<m->nbEBands;i++)
233       {
234          eBands[i+c*m->nbEBands] = log2Amp(oldEBands[i+c*m->nbEBands]);
235          if (oldEBands[i+c*m->nbEBands] < -QCONST16(7.f,DB_SHIFT))
236             oldEBands[i+c*m->nbEBands] = -QCONST16(7.f,DB_SHIFT);
237       }
238    } while (++c < C);
239 }
240
241 void unquant_coarse_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int intra, int *prob, ec_dec *dec, int _C, int LM)
242 {
243    int i, c;
244    celt_word32 prev[2] = {0, 0};
245    celt_word16 coef;
246    celt_word16 beta;
247    const int C = CHANNELS(_C);
248
249    coef = pred_coef[LM];
250
251    if (intra)
252    {
253       coef = 0;
254       prob += 2*m->nbEBands;
255    }
256    /* No theoretical justification for this, it just works */
257    beta = MULT16_16_P15(coef,coef);
258
259    /* Decode at a fixed coarse resolution */
260    for (i=start;i<end;i++)
261    {
262       c=0;
263       do {
264          int qi;
265          celt_word16 q;
266          celt_word32 mean =  (i-start < E_MEANS_SIZE) ? SUB32(SHL32(EXTEND32(eMeans[i-start]),15), MULT16_16(coef,eMeans[i-start])) : 0;
267          qi = ec_laplace_decode_start(dec, prob[2*i], prob[2*i+1]);
268          q = SHL16(qi,DB_SHIFT);
269
270          oldEBands[i+c*m->nbEBands] = PSHR32(MULT16_16(coef,oldEBands[i+c*m->nbEBands]) + mean + prev[c] + SHL32(EXTEND32(q),15), 15);
271          prev[c] = mean + prev[c] + SHL32(EXTEND32(q),15) - MULT16_16(beta,q);
272       } while (++c < C);
273    }
274 }
275
276 void unquant_fine_energy(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int *fine_quant, ec_dec *dec, int _C)
277 {
278    int i, c;
279    const int C = CHANNELS(_C);
280    /* Decode finer resolution */
281    for (i=start;i<end;i++)
282    {
283       if (fine_quant[i] <= 0)
284          continue;
285       c=0; 
286       do {
287          int q2;
288          celt_word16 offset;
289          q2 = ec_dec_bits(dec, fine_quant[i]);
290 #ifdef FIXED_POINT
291          offset = SUB16(SHR16(SHL16(q2,DB_SHIFT)+QCONST16(.5,DB_SHIFT),fine_quant[i]),QCONST16(.5f,DB_SHIFT));
292 #else
293          offset = (q2+.5f)*(1<<(14-fine_quant[i]))*(1.f/16384) - .5f;
294 #endif
295          oldEBands[i+c*m->nbEBands] += offset;
296       } while (++c < C);
297    }
298 }
299
300 void unquant_energy_finalise(const CELTMode *m, int start, int end, celt_ener *eBands, celt_word16 *oldEBands, int *fine_quant,  int *fine_priority, int bits_left, ec_dec *dec, int _C)
301 {
302    int i, prio, c;
303    const int C = CHANNELS(_C);
304
305    /* Use up the remaining bits */
306    for (prio=0;prio<2;prio++)
307    {
308       for (i=start;i<end && bits_left>=C ;i++)
309       {
310          if (fine_quant[i] >= 7 || fine_priority[i]!=prio)
311             continue;
312          c=0;
313          do {
314             int q2;
315             celt_word16 offset;
316             q2 = ec_dec_bits(dec, 1);
317 #ifdef FIXED_POINT
318             offset = SHR16(SHL16(q2,DB_SHIFT)-QCONST16(.5,DB_SHIFT),fine_quant[i]+1);
319 #else
320             offset = (q2-.5f)*(1<<(14-fine_quant[i]-1))*(1.f/16384);
321 #endif
322             oldEBands[i+c*m->nbEBands] += offset;
323             bits_left--;
324          } while (++c < C);
325       }
326    }
327    c=0;
328    do {
329       for (i=start;i<m->nbEBands;i++)
330       {
331          eBands[i+c*m->nbEBands] = log2Amp(oldEBands[i+c*m->nbEBands]);
332          if (oldEBands[i+c*m->nbEBands] < -QCONST16(7.f,DB_SHIFT))
333             oldEBands[i+c*m->nbEBands] = -QCONST16(7.f,DB_SHIFT);
334       }
335    } while (++c < C);
336 }