A bunch of const qualifyers and a few comments
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72    SAVE_STACK;
73 #ifdef FIXED_POINT
74    int yshift = 14-EC_ILOG(K);
75 #endif
76    ALLOC(y, N, celt_norm_t);
77
78    /*for (i=0;i<N;i++)
79    printf ("%d ", iy[i]);*/
80    Rpp = 0;
81    for (i=0;i<N;i++)
82       Rpp = MAC16_16(Rpp,P[i],P[i]);
83
84    Ryp = 0;
85    for (i=0;i<N;i++)
86       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
87
88    /* Remove part of the pitch component to compute the real residual from
89       the encoded (int) one */
90    for (i=0;i<N;i++)
91       y[i] = SUB16(SHL16(iy[i],yshift),
92                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
93
94    /* Recompute after the projection (I think it's right) */
95    Ryp = 0;
96    for (i=0;i<N;i++)
97       Ryp = MAC16_16(Ryp,y[i],P[i]);
98
99    Ryy = 0;
100    for (i=0;i<N;i++)
101       Ryy = MAC16_16(Ryy, y[i],y[i]);
102
103    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
104    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
105
106    for (i=0;i<N;i++)
107       X[i] = P[i] + MULT16_32_Q14(y[i], g);
108    RESTORE_STACK;
109 }
110
111 /** All the info necessary to keep track of a hypothesis during the search */
112 struct NBest {
113    celt_word32_t score;
114    int sign;
115    int pos;
116    int orig;
117    celt_word32_t xy;
118    celt_word32_t yy;
119    celt_word32_t yp;
120 };
121
122 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
123 {
124    int L = 3;
125    VARDECL(celt_norm_t *_y);
126    VARDECL(celt_norm_t *_ny);
127    VARDECL(int *_iy);
128    VARDECL(int *_iny);
129    VARDECL(celt_norm_t **y);
130    VARDECL(celt_norm_t **ny);
131    VARDECL(int **iy);
132    VARDECL(int **iny);
133    int i, j, k, m;
134    int pulsesLeft;
135    VARDECL(celt_word32_t *xy);
136    VARDECL(celt_word32_t *yy);
137    VARDECL(celt_word32_t *yp);
138    VARDECL(struct NBest *_nbest);
139    VARDECL(struct NBest **nbest);
140    celt_word32_t Rpp=0, Rxp=0;
141    int maxL = 1;
142 #ifdef FIXED_POINT
143    int yshift;
144 #endif
145    SAVE_STACK;
146
147 #ifdef FIXED_POINT
148    yshift = 14-EC_ILOG(K);
149 #endif
150
151    ALLOC(_y, L*N, celt_norm_t);
152    ALLOC(_ny, L*N, celt_norm_t);
153    ALLOC(_iy, L*N, int);
154    ALLOC(_iny, L*N, int);
155    ALLOC(y, L, celt_norm_t*);
156    ALLOC(ny, L, celt_norm_t*);
157    ALLOC(iy, L, int*);
158    ALLOC(iny, L, int*);
159    
160    ALLOC(xy, L, celt_word32_t);
161    ALLOC(yy, L, celt_word32_t);
162    ALLOC(yp, L, celt_word32_t);
163    ALLOC(_nbest, L, struct NBest);
164    ALLOC(nbest, L, struct NBest *);
165    
166    for (m=0;m<L;m++)
167       nbest[m] = &_nbest[m];
168    
169    for (m=0;m<L;m++)
170    {
171       ny[m] = &_ny[m*N];
172       iny[m] = &_iny[m*N];
173       y[m] = &_y[m*N];
174       iy[m] = &_iy[m*N];
175    }
176    
177    for (j=0;j<N;j++)
178    {
179       Rpp = MAC16_16(Rpp, P[j],P[j]);
180       Rxp = MAC16_16(Rxp, X[j],P[j]);
181    }
182    Rpp = ROUND(Rpp, NORM_SHIFT);
183    Rxp = ROUND(Rxp, NORM_SHIFT);
184    if (Rpp>NORM_SCALING)
185       celt_fatal("Rpp > 1");
186
187    /* We only need to initialise the zero because the first iteration only uses that */
188    for (i=0;i<N;i++)
189       y[0][i] = 0;
190    for (i=0;i<N;i++)
191       iy[0][i] = 0;
192    xy[0] = yy[0] = yp[0] = 0;
193
194    pulsesLeft = K;
195    while (pulsesLeft > 0)
196    {
197       int pulsesAtOnce=1;
198       int Lupdate = L;
199       int L2 = L;
200       
201       /* Decide on complexity strategy */
202       pulsesAtOnce = pulsesLeft/N;
203       if (pulsesAtOnce<1)
204          pulsesAtOnce = 1;
205       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
206          Lupdate = 1;
207       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
208       L2 = Lupdate;
209       if (L2>maxL)
210       {
211          L2 = maxL;
212          maxL *= N;
213       }
214
215       for (m=0;m<Lupdate;m++)
216          nbest[m]->score = -VERY_LARGE32;
217
218       for (m=0;m<L2;m++)
219       {
220          for (j=0;j<N;j++)
221          {
222             int sign;
223             /*if (x[j]>0) sign=1; else sign=-1;*/
224             for (sign=-1;sign<=1;sign+=2)
225             {
226                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
227                celt_word32_t Rxy, Ryy, Ryp;
228                celt_word16_t spj, aspj; /* Intermediate results */
229                celt_word32_t score;
230                celt_word32_t g;
231                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
232                
233                /* All pulses at one location must have the same sign. */
234                if (iy[m][j]*sign < 0)
235                   continue;
236
237                spj = MULT16_16_P14(s, P[j]);
238                aspj = MULT16_16_P15(alpha, spj);
239                /* Updating the sums of the new pulse(s) */
240                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_P15(alpha,spj),Rxp);
241                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
242                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
243                
244                /* Compute the gain such that ||p + g*y|| = 1 */
245                g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),Rpp)) - ROUND(Ryp,14),14),ROUND(Ryy,14));
246                
247                /* Knowing that gain, what the error: (x-g*y)^2 
248                   (result is negated and we discard x^2 because it's constant) */
249                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
250                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g) -
251                      MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
252
253                if (score>nbest[Lupdate-1]->score)
254                {
255                   int k;
256                   int id = Lupdate-1;
257                   struct NBest *tmp_best;
258
259                   /* Save some pointers that would be deleted and use them for the current entry*/
260                   tmp_best = nbest[Lupdate-1];
261                   while (id > 0 && score > nbest[id-1]->score)
262                      id--;
263                
264                   for (k=Lupdate-1;k>id;k--)
265                      nbest[k] = nbest[k-1];
266
267                   nbest[id] = tmp_best;
268                   nbest[id]->score = score;
269                   nbest[id]->pos = j;
270                   nbest[id]->orig = m;
271                   nbest[id]->sign = sign;
272                   nbest[id]->xy = Rxy;
273                   nbest[id]->yy = Ryy;
274                   nbest[id]->yp = Ryp;
275                }
276             }
277          }
278
279       }
280       
281       if (!(nbest[0]->score > -VERY_LARGE32))
282          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
283       /* Only now that we've made the final choice, update ny/iny and others */
284       for (k=0;k<Lupdate;k++)
285       {
286          int n;
287          int is;
288          celt_norm_t s;
289          is = nbest[k]->sign*pulsesAtOnce;
290          s = SHL16(is, yshift);
291          for (n=0;n<N;n++)
292             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
293          ny[k][nbest[k]->pos] += s;
294
295          for (n=0;n<N;n++)
296             iny[k][n] = iy[nbest[k]->orig][n];
297          iny[k][nbest[k]->pos] += is;
298
299          xy[k] = nbest[k]->xy;
300          yy[k] = nbest[k]->yy;
301          yp[k] = nbest[k]->yp;
302       }
303       /* Swap ny/iny with y/iy */
304       for (k=0;k<Lupdate;k++)
305       {
306          celt_norm_t *tmp_ny;
307          int *tmp_iny;
308
309          tmp_ny = ny[k];
310          ny[k] = y[k];
311          y[k] = tmp_ny;
312          tmp_iny = iny[k];
313          iny[k] = iy[k];
314          iy[k] = tmp_iny;
315       }
316       pulsesLeft -= pulsesAtOnce;
317    }
318    
319 #if 0
320    if (0) {
321       celt_word32_t err=0;
322       for (i=0;i<N;i++)
323          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
324       /*if (N<=10)
325         printf ("%f %d %d\n", err, K, N);*/
326    }
327    /* Sanity checks, don't bother */
328    if (0) {
329       for (i=0;i<N;i++)
330          x[i] = p[i]+nbest[0]->gain*y[0][i];
331       celt_word32_t E=1e-15;
332       int ABS = 0;
333       for (i=0;i<N;i++)
334          ABS += abs(iy[0][i]);
335       /*if (K != ABS)
336          printf ("%d %d\n", K, ABS);*/
337       for (i=0;i<N;i++)
338          E += x[i]*x[i];
339       /*printf ("%f\n", E);*/
340       E = 1/sqrt(E);
341       for (i=0;i<N;i++)
342          x[i] *= E;
343    }
344 #endif
345    
346    encode_pulses(iy[0], N, K, enc);
347    
348    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
349       due to the recursive computation used in quantisation.
350       Not quite sure whether we need that or not */
351    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
352    RESTORE_STACK;
353 }
354
355 /** Decode pulse vector and combine the result with the pitch vector to produce
356     the final normalised signal in the current band. */
357 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
358 {
359    VARDECL(int *iy);
360    SAVE_STACK;
361    ALLOC(iy, N, int);
362    decode_pulses(iy, N, K, dec);
363    mix_pitch_and_residual(iy, X, N, K, P, alpha);
364    RESTORE_STACK;
365 }
366
367
368 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
369
370 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
371 {
372    int i,j;
373    int best=0;
374    celt_word32_t best_score=0;
375    celt_word16_t s = 1;
376    int sign;
377    celt_word32_t E;
378    float pred_gain;
379    int max_pos = N0-N/B;
380    if (max_pos > 32)
381       max_pos = 32;
382
383    for (i=0;i<max_pos*B;i+=B)
384    {
385       int j;
386       celt_word32_t xy=0, yy=0;
387       float score;
388       for (j=0;j<N;j++)
389       {
390          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
391          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
392       }
393       score = 1.f*xy*xy/(.001+yy);
394       if (score > best_score)
395       {
396          best_score = score;
397          best = i;
398          if (xy>0)
399             s = 1;
400          else
401             s = -1;
402       }
403    }
404    if (s<0)
405       sign = 1;
406    else
407       sign = 0;
408    /*printf ("%d %d ", sign, best);*/
409    ec_enc_uint(enc,sign,2);
410    ec_enc_uint(enc,best/B,max_pos);
411    /*printf ("%d %f\n", best, best_score);*/
412    
413    if (K>10)
414       pred_gain = pg[10];
415    else
416       pred_gain = pg[K];
417    E = 1e-10;
418    for (j=0;j<N;j++)
419    {
420       P[j] = s*Y[best+N-j-1];
421       E = MAC16_16(E, P[j],P[j]);
422    }
423    pred_gain = NORM_SCALING*pred_gain/sqrt(E);
424    for (j=0;j<N;j++)
425       P[j] *= pred_gain;
426    if (K>0)
427    {
428       for (j=0;j<N;j++)
429          x[j] -= P[j];
430    } else {
431       for (j=0;j<N;j++)
432          x[j] = P[j];
433    }
434    /*printf ("quant ");*/
435    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
436
437 }
438
439 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
440 {
441    int j;
442    int sign;
443    celt_word16_t s;
444    int best;
445    celt_word32_t E;
446    float pred_gain;
447    int max_pos = N0-N/B;
448    if (max_pos > 32)
449       max_pos = 32;
450    
451    sign = ec_dec_uint(dec, 2);
452    if (sign == 0)
453       s = 1;
454    else
455       s = -1;
456    
457    best = B*ec_dec_uint(dec, max_pos);
458    /*printf ("%d %d ", sign, best);*/
459
460    if (K>10)
461       pred_gain = pg[10];
462    else
463       pred_gain = pg[K];
464    E = 1e-10;
465    for (j=0;j<N;j++)
466    {
467       P[j] = s*Y[best+N-j-1];
468       E = MAC16_16(E, P[j],P[j]);
469    }
470    pred_gain = NORM_SCALING*pred_gain/sqrt(E);
471    for (j=0;j<N;j++)
472       P[j] *= pred_gain;
473    if (K==0)
474    {
475       for (j=0;j<N;j++)
476          x[j] = P[j];
477    }
478 }
479
480 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
481 {
482    int i, j;
483    celt_word32_t E;
484    float g;
485    
486    E = 1e-10;
487    if (N0 >= Nmax/2)
488    {
489       for (i=0;i<B;i++)
490       {
491          for (j=0;j<N/B;j++)
492          {
493             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
494             E += P[j*B+i]*P[j*B+i];
495          }
496       }
497    } else {
498       for (j=0;j<N;j++)
499       {
500          P[j] = Y[j];
501          E = MAC16_16(E, P[j],P[j]);
502       }
503    }
504    g = NORM_SCALING/sqrt(E);
505    for (j=0;j<N;j++)
506       P[j] *= g;
507    for (j=0;j<N;j++)
508       x[j] = P[j];
509 }
510