fixed-point: done with mix_pitch_and_residual() though a bit of cleaning up
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42
43 /* Enable this or define your own implementation if you want to speed up the
44    VQ search (used in inner loop only) */
45 #if 0
46 #include <xmmintrin.h>
47 static inline float approx_sqrt(float x)
48 {
49    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
50    return x;
51 }
52 static inline float approx_inv(float x)
53 {
54    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
55    return x;
56 }
57 #else
58 #define approx_sqrt(x) (sqrt(x))
59 #define approx_inv(x) (1.f/(x))
60 #endif
61
62 #ifdef FIXED_POINT
63 #define C0 3634
64 #define C1 21173
65 #define C2 -12627
66 #define C3 4204
67
68 static inline celt_word32_t celt_sqrt(celt_word32_t x)
69 {
70    int k;
71    //printf ("%d ", x);
72    celt_word32_t rt;
73    /* ((EC_ILOG(x)-1)>>1) is just the int log4(x) (EC_ILOG returns log2 + 1) */
74    k = ((EC_ILOG(x)-1)>>1)-6;
75    x = VSHR32(x, (k<<1));
76    rt = ADD16(C0, MULT16_16_Q14(x, ADD16(C1, MULT16_16_Q14(x, ADD16(C2, MULT16_16_Q14(x, (C3)))))));
77    rt = VSHR32(rt,7-k);
78    //printf ("%d\n", rt);
79    return rt;
80 }
81 #else
82 #define celt_sqrt sqrt
83 #endif
84
85 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
86    applies the compression in the pitch direction, computes the gain that will
87    give ||p+g*y||=1 and mixes the residual with the pitch. */
88 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
89 {
90    int i;
91    celt_word32_t Ryp, Ryy, Rpp;
92    celt_word32_t g;
93    VARDECL(celt_norm_t *y);
94 #ifdef FIXED_POINT
95    int yshift = 14-EC_ILOG(K);
96 #endif
97    ALLOC(y, N, celt_norm_t);
98
99    /*for (i=0;i<N;i++)
100    printf ("%d ", iy[i]);*/
101    Rpp = 0;
102    for (i=0;i<N;i++)
103       Rpp = MAC16_16(Rpp,P[i],P[i]);
104
105    Ryp = 0;
106    for (i=0;i<N;i++)
107       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
108
109    /* Remove part of the pitch component to compute the real residual from
110       the encoded (int) one */
111    for (i=0;i<N;i++)
112       y[i] = SUB16(SHL16(iy[i],yshift),
113                    MULT16_16_Q15(alpha,MULT16_16_Q14(EXTRACT16(SHR32(Ryp,14)),P[i])));
114
115    /* Recompute after the projection (I think it's right) */
116    Ryp = 0;
117    for (i=0;i<N;i++)
118       Ryp = MAC16_16(Ryp,y[i],P[i]);
119
120    Ryy = 0;
121    for (i=0;i<N;i++)
122       Ryy = MAC16_16(Ryy, y[i],y[i]);
123
124    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
125    g = DIV32(SHL32(celt_sqrt(MULT16_16(PSHR32(Ryp,14),PSHR32(Ryp,14)) + Ryy - MULT16_16(PSHR32(Ryy,14),PSHR32(Rpp,14))) - PSHR32(Ryp,14),14),PSHR32(Ryy,14));
126
127    for (i=0;i<N;i++)
128       X[i] = P[i] + MULT16_32_Q14(y[i], g);
129 }
130
131 /** All the info necessary to keep track of a hypothesis during the search */
132 struct NBest {
133    float score;
134    float gain;
135    int sign;
136    int pos;
137    int orig;
138    float xy;
139    float yy;
140    float yp;
141 };
142
143 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
144 {
145    int L = 3;
146    VARDECL(float *x);
147    VARDECL(float *p);
148    VARDECL(float *_y);
149    VARDECL(float *_ny);
150    VARDECL(int *_iy);
151    VARDECL(int *_iny);
152    VARDECL(float **y);
153    VARDECL(float **ny);
154    VARDECL(int **iy);
155    VARDECL(int **iny);
156    int i, j, k, m;
157    int pulsesLeft;
158    VARDECL(float *xy);
159    VARDECL(float *yy);
160    VARDECL(float *yp);
161    VARDECL(struct NBest *_nbest);
162    VARDECL(struct NBest **nbest);
163    float Rpp=0, Rxp=0;
164    int maxL = 1;
165    float _alpha = Q15_ONE_1*alpha;
166
167    ALLOC(x, N, float);
168    ALLOC(p, N, float);
169    ALLOC(_y, L*N, float);
170    ALLOC(_ny, L*N, float);
171    ALLOC(_iy, L*N, int);
172    ALLOC(_iny, L*N, int);
173    ALLOC(y, L*N, float*);
174    ALLOC(ny, L*N, float*);
175    ALLOC(iy, L*N, int*);
176    ALLOC(iny, L*N, int*);
177    
178    ALLOC(xy, L, float);
179    ALLOC(yy, L, float);
180    ALLOC(yp, L, float);
181    ALLOC(_nbest, L, struct NBest);
182    ALLOC(nbest, L, struct NBest *);
183
184    for (j=0;j<N;j++)
185    {
186       x[j] = X[j]*NORM_SCALING_1;
187       p[j] = P[j]*NORM_SCALING_1;
188    }
189    
190    for (m=0;m<L;m++)
191       nbest[m] = &_nbest[m];
192    
193    for (m=0;m<L;m++)
194    {
195       ny[m] = &_ny[m*N];
196       iny[m] = &_iny[m*N];
197       y[m] = &_y[m*N];
198       iy[m] = &_iy[m*N];
199    }
200    
201    for (j=0;j<N;j++)
202    {
203       Rpp += p[j]*p[j];
204       Rxp += x[j]*p[j];
205    }
206    if (Rpp>1)
207       celt_fatal("Rpp > 1");
208
209    /* We only need to initialise the zero because the first iteration only uses that */
210    for (i=0;i<N;i++)
211       y[0][i] = 0;
212    for (i=0;i<N;i++)
213       iy[0][i] = 0;
214    xy[0] = yy[0] = yp[0] = 0;
215
216    pulsesLeft = K;
217    while (pulsesLeft > 0)
218    {
219       int pulsesAtOnce=1;
220       int Lupdate = L;
221       int L2 = L;
222       
223       /* Decide on complexity strategy */
224       pulsesAtOnce = pulsesLeft/N;
225       if (pulsesAtOnce<1)
226          pulsesAtOnce = 1;
227       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
228          Lupdate = 1;
229       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
230       L2 = Lupdate;
231       if (L2>maxL)
232       {
233          L2 = maxL;
234          maxL *= N;
235       }
236
237       for (m=0;m<Lupdate;m++)
238          nbest[m]->score = -1e10f;
239
240       for (m=0;m<L2;m++)
241       {
242          for (j=0;j<N;j++)
243          {
244             int sign;
245             /*if (x[j]>0) sign=1; else sign=-1;*/
246             for (sign=-1;sign<=1;sign+=2)
247             {
248                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
249                float tmp_xy, tmp_yy, tmp_yp;
250                float score;
251                float g;
252                float s = sign*pulsesAtOnce;
253                
254                /* All pulses at one location must have the same sign. */
255                if (iy[m][j]*sign < 0)
256                   continue;
257
258                /* Updating the sums of the new pulse(s) */
259                tmp_xy = xy[m] + s*x[j]               - _alpha*s*p[j]*Rxp;
260                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*_alpha*_alpha*p[j]*p[j]*Rpp - 2.f*_alpha*s*p[j]*yp[m] - 2.f*s*s*_alpha*p[j]*p[j];
261                tmp_yp = yp[m] + s*p[j]               *(1.f-_alpha*Rpp);
262                
263                /* Compute the gain such that ||p + g*y|| = 1 */
264                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
265                /* Knowing that gain, what the error: (x-g*y)^2 
266                   (result is negated and we discard x^2 because it's constant) */
267                score = 2.f*g*tmp_xy - g*g*tmp_yy;
268
269                if (score>nbest[Lupdate-1]->score)
270                {
271                   int k;
272                   int id = Lupdate-1;
273                   struct NBest *tmp_best;
274
275                   /* Save some pointers that would be deleted and use them for the current entry*/
276                   tmp_best = nbest[Lupdate-1];
277                   while (id > 0 && score > nbest[id-1]->score)
278                      id--;
279                
280                   for (k=Lupdate-1;k>id;k--)
281                      nbest[k] = nbest[k-1];
282
283                   nbest[id] = tmp_best;
284                   nbest[id]->score = score;
285                   nbest[id]->pos = j;
286                   nbest[id]->orig = m;
287                   nbest[id]->sign = sign;
288                   nbest[id]->gain = g;
289                   nbest[id]->xy = tmp_xy;
290                   nbest[id]->yy = tmp_yy;
291                   nbest[id]->yp = tmp_yp;
292                }
293             }
294          }
295
296       }
297       
298       if (!(nbest[0]->score > -1e10f))
299          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
300       /* Only now that we've made the final choice, update ny/iny and others */
301       for (k=0;k<Lupdate;k++)
302       {
303          int n;
304          int is;
305          float s;
306          is = nbest[k]->sign*pulsesAtOnce;
307          s = is;
308          for (n=0;n<N;n++)
309             ny[k][n] = y[nbest[k]->orig][n] - _alpha*s*p[nbest[k]->pos]*p[n];
310          ny[k][nbest[k]->pos] += s;
311
312          for (n=0;n<N;n++)
313             iny[k][n] = iy[nbest[k]->orig][n];
314          iny[k][nbest[k]->pos] += is;
315
316          xy[k] = nbest[k]->xy;
317          yy[k] = nbest[k]->yy;
318          yp[k] = nbest[k]->yp;
319       }
320       /* Swap ny/iny with y/iy */
321       for (k=0;k<Lupdate;k++)
322       {
323          float *tmp_ny;
324          int *tmp_iny;
325
326          tmp_ny = ny[k];
327          ny[k] = y[k];
328          y[k] = tmp_ny;
329          tmp_iny = iny[k];
330          iny[k] = iy[k];
331          iy[k] = tmp_iny;
332       }
333       pulsesLeft -= pulsesAtOnce;
334    }
335    
336 #if 0
337    if (0) {
338       float err=0;
339       for (i=0;i<N;i++)
340          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
341       /*if (N<=10)
342         printf ("%f %d %d\n", err, K, N);*/
343    }
344    /* Sanity checks, don't bother */
345    if (0) {
346       for (i=0;i<N;i++)
347          x[i] = p[i]+nbest[0]->gain*y[0][i];
348       float E=1e-15;
349       int ABS = 0;
350       for (i=0;i<N;i++)
351          ABS += abs(iy[0][i]);
352       /*if (K != ABS)
353          printf ("%d %d\n", K, ABS);*/
354       for (i=0;i<N;i++)
355          E += x[i]*x[i];
356       /*printf ("%f\n", E);*/
357       E = 1/sqrt(E);
358       for (i=0;i<N;i++)
359          x[i] *= E;
360    }
361 #endif
362    
363    encode_pulses(iy[0], N, K, enc);
364    
365    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
366       due to the recursive computation used in quantisation.
367       Not quite sure whether we need that or not */
368    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
369 }
370
371 /** Decode pulse vector and combine the result with the pitch vector to produce
372     the final normalised signal in the current band. */
373 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
374 {
375    VARDECL(int *iy);
376    ALLOC(iy, N, int);
377    decode_pulses(iy, N, K, dec);
378    mix_pitch_and_residual(iy, X, N, K, P, alpha);
379 }
380
381
382 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
383
384 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
385 {
386    int i,j;
387    int best=0;
388    float best_score=0;
389    float s = 1;
390    int sign;
391    float E;
392    float pred_gain;
393    int max_pos = N0-N/B;
394    if (max_pos > 32)
395       max_pos = 32;
396
397    for (i=0;i<max_pos*B;i+=B)
398    {
399       int j;
400       float xy=0, yy=0;
401       float score;
402       for (j=0;j<N;j++)
403       {
404          xy += 1.f*x[j]*Y[i+N-j-1];
405          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
406       }
407       score = xy*xy/(.001+yy);
408       if (score > best_score)
409       {
410          best_score = score;
411          best = i;
412          if (xy>0)
413             s = 1;
414          else
415             s = -1;
416       }
417    }
418    if (s<0)
419       sign = 1;
420    else
421       sign = 0;
422    /*printf ("%d %d ", sign, best);*/
423    ec_enc_uint(enc,sign,2);
424    ec_enc_uint(enc,best/B,max_pos);
425    /*printf ("%d %f\n", best, best_score);*/
426    
427    if (K>10)
428       pred_gain = pg[10];
429    else
430       pred_gain = pg[K];
431    E = 1e-10;
432    for (j=0;j<N;j++)
433    {
434       P[j] = s*Y[best+N-j-1];
435       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
436    }
437    E = pred_gain/sqrt(E);
438    for (j=0;j<N;j++)
439       P[j] *= E;
440    if (K>0)
441    {
442       for (j=0;j<N;j++)
443          x[j] -= P[j];
444    } else {
445       for (j=0;j<N;j++)
446          x[j] = P[j];
447    }
448    /*printf ("quant ");*/
449    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
450
451 }
452
453 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
454 {
455    int j;
456    int sign;
457    float s;
458    int best;
459    float E;
460    float pred_gain;
461    int max_pos = N0-N/B;
462    if (max_pos > 32)
463       max_pos = 32;
464    
465    sign = ec_dec_uint(dec, 2);
466    if (sign == 0)
467       s = 1;
468    else
469       s = -1;
470    
471    best = B*ec_dec_uint(dec, max_pos);
472    /*printf ("%d %d ", sign, best);*/
473
474    if (K>10)
475       pred_gain = pg[10];
476    else
477       pred_gain = pg[K];
478    E = 1e-10;
479    for (j=0;j<N;j++)
480    {
481       P[j] = s*Y[best+N-j-1];
482       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
483    }
484    E = pred_gain/sqrt(E);
485    for (j=0;j<N;j++)
486       P[j] *= E;
487    if (K==0)
488    {
489       for (j=0;j<N;j++)
490          x[j] = P[j];
491    }
492 }
493
494 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
495 {
496    int i, j;
497    float E;
498    
499    E = 1e-10;
500    if (N0 >= Nmax/2)
501    {
502       for (i=0;i<B;i++)
503       {
504          for (j=0;j<N/B;j++)
505          {
506             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
507             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
508          }
509       }
510    } else {
511       for (j=0;j<N;j++)
512       {
513          P[j] = Y[j];
514          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
515       }
516    }
517    E = 1.f/sqrt(E);
518    for (j=0;j<N;j++)
519       P[j] *= E;
520    for (j=0;j<N;j++)
521       x[j] = P[j];
522 }
523