Adjusting/fixing warnings
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
45    applies the compression in the pitch direction, computes the gain that will
46    give ||p+g*y||=1 and mixes the residual with the pitch. */
47 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
48 {
49    int i;
50    celt_word32_t Ryp, Ryy, Rpp;
51    celt_word32_t g;
52    VARDECL(celt_norm_t, y);
53 #ifdef FIXED_POINT
54    int yshift;
55 #endif
56    SAVE_STACK;
57 #ifdef FIXED_POINT
58    yshift = 14-EC_ILOG(K);
59 #endif
60    ALLOC(y, N, celt_norm_t);
61
62    /*for (i=0;i<N;i++)
63    printf ("%d ", iy[i]);*/
64    Rpp = 0;
65    for (i=0;i<N;i++)
66       Rpp = MAC16_16(Rpp,P[i],P[i]);
67
68    Ryp = 0;
69    for (i=0;i<N;i++)
70       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
71
72    /* Remove part of the pitch component to compute the real residual from
73       the encoded (int) one */
74    for (i=0;i<N;i++)
75       y[i] = SUB16(SHL16(iy[i],yshift),
76                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
77
78    /* Recompute after the projection (I think it's right) */
79    Ryp = 0;
80    for (i=0;i<N;i++)
81       Ryp = MAC16_16(Ryp,y[i],P[i]);
82
83    Ryy = 0;
84    for (i=0;i<N;i++)
85       Ryy = MAC16_16(Ryy, y[i],y[i]);
86
87    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
88    g = MULT16_32_Q15(
89             celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy -
90                       MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14)))
91             - ROUND(Ryp,14),
92        celt_rcp(SHR32(Ryy,9)));
93
94    for (i=0;i<N;i++)
95       X[i] = P[i] + ROUND(MULT16_16(y[i], g),11);
96    RESTORE_STACK;
97 }
98
99 /** All the info necessary to keep track of a hypothesis during the search */
100 struct NBest {
101    celt_word32_t score;
102    int sign;
103    int pos;
104    int orig;
105    celt_word32_t xy;
106    celt_word32_t yy;
107    celt_word32_t yp;
108 };
109
110 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
111 {
112    int L = 3;
113    VARDECL(celt_norm_t, _y);
114    VARDECL(celt_norm_t, _ny);
115    VARDECL(int, _iy);
116    VARDECL(int, _iny);
117    VARDECL(celt_norm_t *, y);
118    VARDECL(celt_norm_t *, ny);
119    VARDECL(int *, iy);
120    VARDECL(int *, iny);
121    int i, j, k, m;
122    int pulsesLeft;
123    VARDECL(celt_word32_t, xy);
124    VARDECL(celt_word32_t, yy);
125    VARDECL(celt_word32_t, yp);
126    VARDECL(struct NBest, _nbest);
127    VARDECL(struct NBest *, nbest);
128    celt_word32_t Rpp=0, Rxp=0;
129    int maxL = 1;
130 #ifdef FIXED_POINT
131    int yshift;
132 #endif
133    SAVE_STACK;
134
135 #ifdef FIXED_POINT
136    yshift = 14-EC_ILOG(K);
137 #endif
138
139    ALLOC(_y, L*N, celt_norm_t);
140    ALLOC(_ny, L*N, celt_norm_t);
141    ALLOC(_iy, L*N, int);
142    ALLOC(_iny, L*N, int);
143    ALLOC(y, L, celt_norm_t*);
144    ALLOC(ny, L, celt_norm_t*);
145    ALLOC(iy, L, int*);
146    ALLOC(iny, L, int*);
147    
148    ALLOC(xy, L, celt_word32_t);
149    ALLOC(yy, L, celt_word32_t);
150    ALLOC(yp, L, celt_word32_t);
151    ALLOC(_nbest, L, struct NBest);
152    ALLOC(nbest, L, struct NBest *);
153    
154    for (m=0;m<L;m++)
155       nbest[m] = &_nbest[m];
156    
157    for (m=0;m<L;m++)
158    {
159       ny[m] = &_ny[m*N];
160       iny[m] = &_iny[m*N];
161       y[m] = &_y[m*N];
162       iy[m] = &_iy[m*N];
163    }
164    
165    for (j=0;j<N;j++)
166    {
167       Rpp = MAC16_16(Rpp, P[j],P[j]);
168       Rxp = MAC16_16(Rxp, X[j],P[j]);
169    }
170    Rpp = ROUND(Rpp, NORM_SHIFT);
171    Rxp = ROUND(Rxp, NORM_SHIFT);
172    if (Rpp>NORM_SCALING)
173       celt_fatal("Rpp > 1");
174
175    /* We only need to initialise the zero because the first iteration only uses that */
176    for (i=0;i<N;i++)
177       y[0][i] = 0;
178    for (i=0;i<N;i++)
179       iy[0][i] = 0;
180    xy[0] = yy[0] = yp[0] = 0;
181
182    pulsesLeft = K;
183    while (pulsesLeft > 0)
184    {
185       int pulsesAtOnce=1;
186       int Lupdate = L;
187       int L2 = L;
188       
189       /* Decide on complexity strategy */
190       pulsesAtOnce = pulsesLeft/N;
191       if (pulsesAtOnce<1)
192          pulsesAtOnce = 1;
193       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
194          Lupdate = 1;
195       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
196       L2 = Lupdate;
197       if (L2>maxL)
198       {
199          L2 = maxL;
200          maxL *= N;
201       }
202
203       for (m=0;m<Lupdate;m++)
204          nbest[m]->score = -VERY_LARGE32;
205
206       for (m=0;m<L2;m++)
207       {
208          for (j=0;j<N;j++)
209          {
210             int sign;
211             /*if (x[j]>0) sign=1; else sign=-1;*/
212             for (sign=-1;sign<=1;sign+=2)
213             {
214                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
215                celt_word32_t Rxy, Ryy, Ryp;
216                celt_word16_t spj, aspj; /* Intermediate results */
217                celt_word32_t score;
218                celt_word32_t g;
219                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
220                
221                /* All pulses at one location must have the same sign. */
222                if (iy[m][j]*sign < 0)
223                   continue;
224
225                spj = MULT16_16_Q14(s, P[j]);
226                aspj = MULT16_16_Q15(alpha, spj);
227                /* Updating the sums of the new pulse(s) */
228                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_Q15(alpha,spj),Rxp);
229                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
230                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
231                
232                /* Compute the gain such that ||p + g*y|| = 1 */
233                g = MULT16_32_Q15(
234                         celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy -
235                                   MULT16_16(ROUND(Ryy,14),Rpp))
236                         - ROUND(Ryp,14),
237                    celt_rcp(SHR32(Ryy,12)));
238                /* Knowing that gain, what's the error: (x-g*y)^2 
239                   (result is negated and we discard x^2 because it's constant) */
240                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
241                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g)
242                        - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
243
244                if (score>nbest[Lupdate-1]->score)
245                {
246                   int id = Lupdate-1;
247                   struct NBest *tmp_best;
248
249                   /* Save some pointers that would be deleted and use them for the current entry*/
250                   tmp_best = nbest[Lupdate-1];
251                   while (id > 0 && score > nbest[id-1]->score)
252                      id--;
253                
254                   for (k=Lupdate-1;k>id;k--)
255                      nbest[k] = nbest[k-1];
256
257                   nbest[id] = tmp_best;
258                   nbest[id]->score = score;
259                   nbest[id]->pos = j;
260                   nbest[id]->orig = m;
261                   nbest[id]->sign = sign;
262                   nbest[id]->xy = Rxy;
263                   nbest[id]->yy = Ryy;
264                   nbest[id]->yp = Ryp;
265                }
266             }
267          }
268
269       }
270       
271       if (!(nbest[0]->score > -VERY_LARGE32))
272          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
273       /* Only now that we've made the final choice, update ny/iny and others */
274       for (k=0;k<Lupdate;k++)
275       {
276          int n;
277          int is;
278          celt_norm_t s;
279          is = nbest[k]->sign*pulsesAtOnce;
280          s = SHL16(is, yshift);
281          for (n=0;n<N;n++)
282             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
283          ny[k][nbest[k]->pos] += s;
284
285          for (n=0;n<N;n++)
286             iny[k][n] = iy[nbest[k]->orig][n];
287          iny[k][nbest[k]->pos] += is;
288
289          xy[k] = nbest[k]->xy;
290          yy[k] = nbest[k]->yy;
291          yp[k] = nbest[k]->yp;
292       }
293       /* Swap ny/iny with y/iy */
294       for (k=0;k<Lupdate;k++)
295       {
296          celt_norm_t *tmp_ny;
297          int *tmp_iny;
298
299          tmp_ny = ny[k];
300          ny[k] = y[k];
301          y[k] = tmp_ny;
302          tmp_iny = iny[k];
303          iny[k] = iy[k];
304          iy[k] = tmp_iny;
305       }
306       pulsesLeft -= pulsesAtOnce;
307    }
308    
309 #if 0
310    if (0) {
311       celt_word32_t err=0;
312       for (i=0;i<N;i++)
313          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
314       /*if (N<=10)
315         printf ("%f %d %d\n", err, K, N);*/
316    }
317    /* Sanity checks, don't bother */
318    if (0) {
319       for (i=0;i<N;i++)
320          x[i] = p[i]+nbest[0]->gain*y[0][i];
321       celt_word32_t E=1e-15;
322       int ABS = 0;
323       for (i=0;i<N;i++)
324          ABS += abs(iy[0][i]);
325       /*if (K != ABS)
326          printf ("%d %d\n", K, ABS);*/
327       for (i=0;i<N;i++)
328          E += x[i]*x[i];
329       /*printf ("%f\n", E);*/
330       E = 1/sqrt(E);
331       for (i=0;i<N;i++)
332          x[i] *= E;
333    }
334 #endif
335    
336    encode_pulses(iy[0], N, K, enc);
337    
338    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
339       due to the recursive computation used in quantisation.
340       Not quite sure whether we need that or not */
341    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
342    RESTORE_STACK;
343 }
344
345 /** Decode pulse vector and combine the result with the pitch vector to produce
346     the final normalised signal in the current band. */
347 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
348 {
349    VARDECL(int, iy);
350    SAVE_STACK;
351    ALLOC(iy, N, int);
352    decode_pulses(iy, N, K, dec);
353    mix_pitch_and_residual(iy, X, N, K, P, alpha);
354    RESTORE_STACK;
355 }
356
357 #ifdef FIXED_POINT
358 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
359 #else
360 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
361 #endif
362
363 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
364 {
365    int i,j;
366    int best=0;
367    celt_word32_t best_score=0;
368    celt_word16_t s = 1;
369    int sign;
370    celt_word32_t E;
371    celt_word16_t pred_gain;
372    int max_pos = N0-N/B;
373    if (max_pos > 32)
374       max_pos = 32;
375
376    for (i=0;i<max_pos*B;i+=B)
377    {
378       celt_word32_t xy=0, yy=0;
379       celt_word32_t score;
380       for (j=0;j<N;j++)
381       {
382          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
383          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
384       }
385       score = DIV32(MULT16_16(ROUND(xy,14),ROUND(xy,14)), ROUND(yy,14));
386       if (score > best_score)
387       {
388          best_score = score;
389          best = i;
390          if (xy>0)
391             s = 1;
392          else
393             s = -1;
394       }
395    }
396    if (s<0)
397       sign = 1;
398    else
399       sign = 0;
400    /*printf ("%d %d ", sign, best);*/
401    ec_enc_uint(enc,sign,2);
402    ec_enc_uint(enc,best/B,max_pos);
403    /*printf ("%d %f\n", best, best_score);*/
404    
405    if (K>10)
406       pred_gain = pg[10];
407    else
408       pred_gain = pg[K];
409    E = EPSILON;
410    for (j=0;j<N;j++)
411    {
412       P[j] = s*Y[best+N-j-1];
413       E = MAC16_16(E, P[j],P[j]);
414    }
415    /*pred_gain = pred_gain/sqrt(E);*/
416    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
417    for (j=0;j<N;j++)
418       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
419    if (K>0)
420    {
421       for (j=0;j<N;j++)
422          x[j] -= P[j];
423    } else {
424       for (j=0;j<N;j++)
425          x[j] = P[j];
426    }
427    /*printf ("quant ");*/
428    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
429
430 }
431
432 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
433 {
434    int j;
435    int sign;
436    celt_word16_t s;
437    int best;
438    celt_word32_t E;
439    celt_word16_t pred_gain;
440    int max_pos = N0-N/B;
441    if (max_pos > 32)
442       max_pos = 32;
443    
444    sign = ec_dec_uint(dec, 2);
445    if (sign == 0)
446       s = 1;
447    else
448       s = -1;
449    
450    best = B*ec_dec_uint(dec, max_pos);
451    /*printf ("%d %d ", sign, best);*/
452
453    if (K>10)
454       pred_gain = pg[10];
455    else
456       pred_gain = pg[K];
457    E = EPSILON;
458    for (j=0;j<N;j++)
459    {
460       P[j] = s*Y[best+N-j-1];
461       E = MAC16_16(E, P[j],P[j]);
462    }
463    /*pred_gain = pred_gain/sqrt(E);*/
464    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
465    for (j=0;j<N;j++)
466       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
467    if (K==0)
468    {
469       for (j=0;j<N;j++)
470          x[j] = P[j];
471    }
472 }
473
474 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
475 {
476    int i, j;
477    celt_word32_t E;
478    celt_word16_t g;
479    
480    E = EPSILON;
481    if (N0 >= Nmax/2)
482    {
483       for (i=0;i<B;i++)
484       {
485          for (j=0;j<N/B;j++)
486          {
487             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
488             E += P[j*B+i]*P[j*B+i];
489          }
490       }
491    } else {
492       for (j=0;j<N;j++)
493       {
494          P[j] = Y[j];
495          E = MAC16_16(E, P[j],P[j]);
496       }
497    }
498    g = DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E));
499    for (j=0;j<N;j++)
500       P[j] = PSHR32(MULT16_16(g, P[j]),8);
501    for (j=0;j<N;j++)
502       x[j] = P[j];
503 }
504