celt_mask_t for masking curves
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42
43 /* Enable this or define your own implementation if you want to speed up the
44    VQ search (used in inner loop only) */
45 #if 0
46 #include <xmmintrin.h>
47 static inline float approx_sqrt(float x)
48 {
49    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
50    return x;
51 }
52 static inline float approx_inv(float x)
53 {
54    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
55    return x;
56 }
57 #else
58 #define approx_sqrt(x) (sqrt(x))
59 #define approx_inv(x) (1.f/(x))
60 #endif
61
62 /** All the info necessary to keep track of a hypothesis during the search */
63 struct NBest {
64    float score;
65    float gain;
66    int sign;
67    int pos;
68    int orig;
69    float xy;
70    float yy;
71    float yp;
72 };
73
74 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, float alpha, ec_enc *enc)
75 {
76    int L = 3;
77    VARDECL(float *x);
78    VARDECL(float *p);
79    VARDECL(float *_y);
80    VARDECL(float *_ny);
81    VARDECL(int *_iy);
82    VARDECL(int *_iny);
83    VARDECL(float **y);
84    VARDECL(float **ny);
85    VARDECL(int **iy);
86    VARDECL(int **iny);
87    int i, j, k, m;
88    int pulsesLeft;
89    VARDECL(float *xy);
90    VARDECL(float *yy);
91    VARDECL(float *yp);
92    VARDECL(struct NBest *_nbest);
93    VARDECL(struct NBest **nbest);
94    float Rpp=0, Rxp=0;
95    int maxL = 1;
96    
97    ALLOC(x, N, float);
98    ALLOC(p, N, float);
99    ALLOC(_y, L*N, float);
100    ALLOC(_ny, L*N, float);
101    ALLOC(_iy, L*N, int);
102    ALLOC(_iny, L*N, int);
103    ALLOC(y, L*N, float*);
104    ALLOC(ny, L*N, float*);
105    ALLOC(iy, L*N, int*);
106    ALLOC(iny, L*N, int*);
107    
108    ALLOC(xy, L, float);
109    ALLOC(yy, L, float);
110    ALLOC(yp, L, float);
111    ALLOC(_nbest, L, struct NBest);
112    ALLOC(nbest, L, struct NBest *);
113
114    for (j=0;j<N;j++)
115    {
116       x[j] = X[j]*NORM_SCALING_1;
117       p[j] = P[j]*NORM_SCALING_1;
118    }
119    
120    for (m=0;m<L;m++)
121       nbest[m] = &_nbest[m];
122    
123    for (m=0;m<L;m++)
124    {
125       ny[m] = &_ny[m*N];
126       iny[m] = &_iny[m*N];
127       y[m] = &_y[m*N];
128       iy[m] = &_iy[m*N];
129    }
130    
131    for (j=0;j<N;j++)
132    {
133       Rpp += p[j]*p[j];
134       Rxp += x[j]*p[j];
135    }
136    if (Rpp>1)
137       celt_fatal("Rpp > 1");
138
139    /* We only need to initialise the zero because the first iteration only uses that */
140    for (i=0;i<N;i++)
141       y[0][i] = 0;
142    for (i=0;i<N;i++)
143       iy[0][i] = 0;
144    xy[0] = yy[0] = yp[0] = 0;
145
146    pulsesLeft = K;
147    while (pulsesLeft > 0)
148    {
149       int pulsesAtOnce=1;
150       int Lupdate = L;
151       int L2 = L;
152       
153       /* Decide on complexity strategy */
154       pulsesAtOnce = pulsesLeft/N;
155       if (pulsesAtOnce<1)
156          pulsesAtOnce = 1;
157       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
158          Lupdate = 1;
159       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
160       L2 = Lupdate;
161       if (L2>maxL)
162       {
163          L2 = maxL;
164          maxL *= N;
165       }
166
167       for (m=0;m<Lupdate;m++)
168          nbest[m]->score = -1e10f;
169
170       for (m=0;m<L2;m++)
171       {
172          for (j=0;j<N;j++)
173          {
174             int sign;
175             /*if (x[j]>0) sign=1; else sign=-1;*/
176             for (sign=-1;sign<=1;sign+=2)
177             {
178                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
179                float tmp_xy, tmp_yy, tmp_yp;
180                float score;
181                float g;
182                float s = sign*pulsesAtOnce;
183                
184                /* All pulses at one location must have the same sign. */
185                if (iy[m][j]*sign < 0)
186                   continue;
187
188                /* Updating the sums of the new pulse(s) */
189                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
190                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2.f*alpha*s*p[j]*yp[m] - 2.f*s*s*alpha*p[j]*p[j];
191                tmp_yp = yp[m] + s*p[j]               *(1.f-alpha*Rpp);
192                
193                /* Compute the gain such that ||p + g*y|| = 1 */
194                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
195                /* Knowing that gain, what the error: (x-g*y)^2 
196                   (result is negated and we discard x^2 because it's constant) */
197                score = 2.f*g*tmp_xy - g*g*tmp_yy;
198
199                if (score>nbest[Lupdate-1]->score)
200                {
201                   int k;
202                   int id = Lupdate-1;
203                   struct NBest *tmp_best;
204
205                   /* Save some pointers that would be deleted and use them for the current entry*/
206                   tmp_best = nbest[Lupdate-1];
207                   while (id > 0 && score > nbest[id-1]->score)
208                      id--;
209                
210                   for (k=Lupdate-1;k>id;k--)
211                      nbest[k] = nbest[k-1];
212
213                   nbest[id] = tmp_best;
214                   nbest[id]->score = score;
215                   nbest[id]->pos = j;
216                   nbest[id]->orig = m;
217                   nbest[id]->sign = sign;
218                   nbest[id]->gain = g;
219                   nbest[id]->xy = tmp_xy;
220                   nbest[id]->yy = tmp_yy;
221                   nbest[id]->yp = tmp_yp;
222                }
223             }
224          }
225
226       }
227       
228       if (!(nbest[0]->score > -1e10f))
229          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
230       /* Only now that we've made the final choice, update ny/iny and others */
231       for (k=0;k<Lupdate;k++)
232       {
233          int n;
234          int is;
235          float s;
236          is = nbest[k]->sign*pulsesAtOnce;
237          s = is;
238          for (n=0;n<N;n++)
239             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
240          ny[k][nbest[k]->pos] += s;
241
242          for (n=0;n<N;n++)
243             iny[k][n] = iy[nbest[k]->orig][n];
244          iny[k][nbest[k]->pos] += is;
245
246          xy[k] = nbest[k]->xy;
247          yy[k] = nbest[k]->yy;
248          yp[k] = nbest[k]->yp;
249       }
250       /* Swap ny/iny with y/iy */
251       for (k=0;k<Lupdate;k++)
252       {
253          float *tmp_ny;
254          int *tmp_iny;
255
256          tmp_ny = ny[k];
257          ny[k] = y[k];
258          y[k] = tmp_ny;
259          tmp_iny = iny[k];
260          iny[k] = iy[k];
261          iy[k] = tmp_iny;
262       }
263       pulsesLeft -= pulsesAtOnce;
264    }
265    
266    if (0) {
267       float err=0;
268       for (i=0;i<N;i++)
269          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
270       /*if (N<=10)
271         printf ("%f %d %d\n", err, K, N);*/
272    }
273    for (i=0;i<N;i++)
274       x[i] = p[i]+nbest[0]->gain*y[0][i];
275    /* Sanity checks, don't bother */
276    if (0) {
277       float E=1e-15;
278       int ABS = 0;
279       for (i=0;i<N;i++)
280          ABS += abs(iy[0][i]);
281       /*if (K != ABS)
282          printf ("%d %d\n", K, ABS);*/
283       for (i=0;i<N;i++)
284          E += x[i]*x[i];
285       /*printf ("%f\n", E);*/
286       E = 1/sqrt(E);
287       for (i=0;i<N;i++)
288          x[i] *= E;
289    }
290    
291    encode_pulses(iy[0], N, K, enc);
292    
293    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
294       due to the recursive computation used in quantisation.
295       Not quite sure whether we need that or not */
296    if (1) {
297       float Ryp=0;
298       float Ryy=0;
299       float g=0;
300       
301       for (i=0;i<N;i++)
302          Ryp += iy[0][i]*p[i];
303       
304       for (i=0;i<N;i++)
305          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
306       
307       Ryp = 0;
308       for (i=0;i<N;i++)
309          Ryp += y[0][i]*p[i];
310       
311       for (i=0;i<N;i++)
312          Ryy += y[0][i]*y[0][i];
313       
314       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
315         
316       for (i=0;i<N;i++)
317          x[i] = p[i] + g*y[0][i];
318       
319    }
320    for (j=0;j<N;j++)
321    {
322       X[j] = x[j] * NORM_SCALING;
323       P[j] = p[j] * NORM_SCALING;
324    }
325
326 }
327
328 /** Decode pulse vector and combine the result with the pitch vector to produce
329     the final normalised signal in the current band. */
330 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, float alpha, ec_dec *dec)
331 {
332    int i;
333    float Rpp=0, Ryp=0, Ryy=0;
334    float g;
335    VARDECL(int *iy);
336    VARDECL(float *y);
337    VARDECL(float *x);
338    VARDECL(float *p);
339    
340    ALLOC(iy, N, int);
341    ALLOC(y, N, float);
342    ALLOC(x, N, float);
343    ALLOC(p, N, float);
344
345    decode_pulses(iy, N, K, dec);
346    for (i=0;i<N;i++)
347    {
348       x[i] = X[i]*NORM_SCALING_1;
349       p[i] = P[i]*NORM_SCALING_1;
350    }
351
352    /*for (i=0;i<N;i++)
353       printf ("%d ", iy[i]);*/
354    for (i=0;i<N;i++)
355       Rpp += p[i]*p[i];
356
357    for (i=0;i<N;i++)
358       Ryp += iy[i]*p[i];
359
360    for (i=0;i<N;i++)
361       y[i] = iy[i] - alpha*Ryp*p[i];
362
363    /* Recompute after the projection (I think it's right) */
364    Ryp = 0;
365    for (i=0;i<N;i++)
366       Ryp += y[i]*p[i];
367
368    for (i=0;i<N;i++)
369       Ryy += y[i]*y[i];
370
371    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
372
373    for (i=0;i<N;i++)
374       x[i] = p[i] + g*y[i];
375    for (i=0;i<N;i++)
376    {
377       X[i] = x[i] * NORM_SCALING;
378       P[i] = p[i] * NORM_SCALING;
379    }
380
381 }
382
383
384 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
385
386 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
387 {
388    int i,j;
389    int best=0;
390    float best_score=0;
391    float s = 1;
392    int sign;
393    float E;
394    float pred_gain;
395    int max_pos = N0-N/B;
396    if (max_pos > 32)
397       max_pos = 32;
398
399    for (i=0;i<max_pos*B;i+=B)
400    {
401       int j;
402       float xy=0, yy=0;
403       float score;
404       for (j=0;j<N;j++)
405       {
406          xy += 1.f*x[j]*Y[i+N-j-1];
407          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
408       }
409       score = xy*xy/(.001+yy);
410       if (score > best_score)
411       {
412          best_score = score;
413          best = i;
414          if (xy>0)
415             s = 1;
416          else
417             s = -1;
418       }
419    }
420    if (s<0)
421       sign = 1;
422    else
423       sign = 0;
424    /*printf ("%d %d ", sign, best);*/
425    ec_enc_uint(enc,sign,2);
426    ec_enc_uint(enc,best/B,max_pos);
427    /*printf ("%d %f\n", best, best_score);*/
428    
429    if (K>10)
430       pred_gain = pg[10];
431    else
432       pred_gain = pg[K];
433    E = 1e-10;
434    for (j=0;j<N;j++)
435    {
436       P[j] = s*Y[best+N-j-1];
437       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
438    }
439    E = pred_gain/sqrt(E);
440    for (j=0;j<N;j++)
441       P[j] *= E;
442    if (K>0)
443    {
444       for (j=0;j<N;j++)
445          x[j] -= P[j];
446    } else {
447       for (j=0;j<N;j++)
448          x[j] = P[j];
449    }
450    /*printf ("quant ");*/
451    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
452
453 }
454
455 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
456 {
457    int j;
458    int sign;
459    float s;
460    int best;
461    float E;
462    float pred_gain;
463    int max_pos = N0-N/B;
464    if (max_pos > 32)
465       max_pos = 32;
466    
467    sign = ec_dec_uint(dec, 2);
468    if (sign == 0)
469       s = 1;
470    else
471       s = -1;
472    
473    best = B*ec_dec_uint(dec, max_pos);
474    /*printf ("%d %d ", sign, best);*/
475
476    if (K>10)
477       pred_gain = pg[10];
478    else
479       pred_gain = pg[K];
480    E = 1e-10;
481    for (j=0;j<N;j++)
482    {
483       P[j] = s*Y[best+N-j-1];
484       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
485    }
486    E = pred_gain/sqrt(E);
487    for (j=0;j<N;j++)
488       P[j] *= E;
489    if (K==0)
490    {
491       for (j=0;j<N;j++)
492          x[j] = P[j];
493    }
494 }
495
496 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
497 {
498    int i, j;
499    float E;
500    
501    E = 1e-10;
502    if (N0 >= Nmax/2)
503    {
504       for (i=0;i<B;i++)
505       {
506          for (j=0;j<N/B;j++)
507          {
508             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
509             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
510          }
511       }
512    } else {
513       for (j=0;j<N;j++)
514       {
515          P[j] = Y[j];
516          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
517       }
518    }
519    E = 1.f/sqrt(E);
520    for (j=0;j<N;j++)
521       P[j] *= E;
522    for (j=0;j<N;j++)
523       x[j] = P[j];
524 }
525