fixed-point: another mix_pitch_and_residual() check-point
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42
43 /* Enable this or define your own implementation if you want to speed up the
44    VQ search (used in inner loop only) */
45 #if 0
46 #include <xmmintrin.h>
47 static inline float approx_sqrt(float x)
48 {
49    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
50    return x;
51 }
52 static inline float approx_inv(float x)
53 {
54    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
55    return x;
56 }
57 #else
58 #define approx_sqrt(x) (sqrt(x))
59 #define approx_inv(x) (1.f/(x))
60 #endif
61
62 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
63    applies the compression in the pitch direction, computes the gain that will
64    give ||p+g*y||=1 and mixes the residual with the pitch. */
65 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
66 {
67    int i;
68    celt_word32_t Ryp, Ryy, Rpp;
69    float g;
70    VARDECL(celt_norm_t *y);
71 #ifdef FIXED_POINT
72    int yshift = 15-EC_ILOG(K);
73 #endif
74    ALLOC(y, N, celt_norm_t);
75
76    /*for (i=0;i<N;i++)
77    printf ("%d ", iy[i]);*/
78    Rpp = 0;
79    for (i=0;i<N;i++)
80       Rpp = MAC16_16(Rpp,P[i],P[i]);
81
82    Ryp = 0;
83    for (i=0;i<N;i++)
84       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
85
86    /* Remove part of the pitch component to compute the real residual from
87       the encoded (int) one */
88    for (i=0;i<N;i++)
89       y[i] = SUB16(SHL16(iy[i],yshift),
90                    MULT16_16_Q15(alpha,MULT16_16_Q14(EXTRACT16(SHR32(Ryp,14)),P[i])));
91
92    /* Recompute after the projection (I think it's right) */
93    Ryp = 0;
94    for (i=0;i<N;i++)
95       Ryp = MAC16_16(Ryp,y[i],P[i]);
96
97    Ryy = 0;
98    for (i=0;i<N;i++)
99       Ryy = MAC16_16(Ryy, y[i],y[i]);
100
101    g = (sqrt(NORM_SCALING_1*NORM_SCALING_1*Ryp*Ryp + Ryy - NORM_SCALING_1*NORM_SCALING_1*Ryy*Rpp) - NORM_SCALING_1*Ryp)/Ryy;
102
103    for (i=0;i<N;i++)
104       X[i] = P[i] + NORM_SCALING*g*y[i];
105 }
106
107 /** All the info necessary to keep track of a hypothesis during the search */
108 struct NBest {
109    float score;
110    float gain;
111    int sign;
112    int pos;
113    int orig;
114    float xy;
115    float yy;
116    float yp;
117 };
118
119 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
120 {
121    int L = 3;
122    VARDECL(float *x);
123    VARDECL(float *p);
124    VARDECL(float *_y);
125    VARDECL(float *_ny);
126    VARDECL(int *_iy);
127    VARDECL(int *_iny);
128    VARDECL(float **y);
129    VARDECL(float **ny);
130    VARDECL(int **iy);
131    VARDECL(int **iny);
132    int i, j, k, m;
133    int pulsesLeft;
134    VARDECL(float *xy);
135    VARDECL(float *yy);
136    VARDECL(float *yp);
137    VARDECL(struct NBest *_nbest);
138    VARDECL(struct NBest **nbest);
139    float Rpp=0, Rxp=0;
140    int maxL = 1;
141    float _alpha = Q15_ONE_1*alpha;
142
143    ALLOC(x, N, float);
144    ALLOC(p, N, float);
145    ALLOC(_y, L*N, float);
146    ALLOC(_ny, L*N, float);
147    ALLOC(_iy, L*N, int);
148    ALLOC(_iny, L*N, int);
149    ALLOC(y, L*N, float*);
150    ALLOC(ny, L*N, float*);
151    ALLOC(iy, L*N, int*);
152    ALLOC(iny, L*N, int*);
153    
154    ALLOC(xy, L, float);
155    ALLOC(yy, L, float);
156    ALLOC(yp, L, float);
157    ALLOC(_nbest, L, struct NBest);
158    ALLOC(nbest, L, struct NBest *);
159
160    for (j=0;j<N;j++)
161    {
162       x[j] = X[j]*NORM_SCALING_1;
163       p[j] = P[j]*NORM_SCALING_1;
164    }
165    
166    for (m=0;m<L;m++)
167       nbest[m] = &_nbest[m];
168    
169    for (m=0;m<L;m++)
170    {
171       ny[m] = &_ny[m*N];
172       iny[m] = &_iny[m*N];
173       y[m] = &_y[m*N];
174       iy[m] = &_iy[m*N];
175    }
176    
177    for (j=0;j<N;j++)
178    {
179       Rpp += p[j]*p[j];
180       Rxp += x[j]*p[j];
181    }
182    if (Rpp>1)
183       celt_fatal("Rpp > 1");
184
185    /* We only need to initialise the zero because the first iteration only uses that */
186    for (i=0;i<N;i++)
187       y[0][i] = 0;
188    for (i=0;i<N;i++)
189       iy[0][i] = 0;
190    xy[0] = yy[0] = yp[0] = 0;
191
192    pulsesLeft = K;
193    while (pulsesLeft > 0)
194    {
195       int pulsesAtOnce=1;
196       int Lupdate = L;
197       int L2 = L;
198       
199       /* Decide on complexity strategy */
200       pulsesAtOnce = pulsesLeft/N;
201       if (pulsesAtOnce<1)
202          pulsesAtOnce = 1;
203       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
204          Lupdate = 1;
205       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
206       L2 = Lupdate;
207       if (L2>maxL)
208       {
209          L2 = maxL;
210          maxL *= N;
211       }
212
213       for (m=0;m<Lupdate;m++)
214          nbest[m]->score = -1e10f;
215
216       for (m=0;m<L2;m++)
217       {
218          for (j=0;j<N;j++)
219          {
220             int sign;
221             /*if (x[j]>0) sign=1; else sign=-1;*/
222             for (sign=-1;sign<=1;sign+=2)
223             {
224                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
225                float tmp_xy, tmp_yy, tmp_yp;
226                float score;
227                float g;
228                float s = sign*pulsesAtOnce;
229                
230                /* All pulses at one location must have the same sign. */
231                if (iy[m][j]*sign < 0)
232                   continue;
233
234                /* Updating the sums of the new pulse(s) */
235                tmp_xy = xy[m] + s*x[j]               - _alpha*s*p[j]*Rxp;
236                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*_alpha*_alpha*p[j]*p[j]*Rpp - 2.f*_alpha*s*p[j]*yp[m] - 2.f*s*s*_alpha*p[j]*p[j];
237                tmp_yp = yp[m] + s*p[j]               *(1.f-_alpha*Rpp);
238                
239                /* Compute the gain such that ||p + g*y|| = 1 */
240                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
241                /* Knowing that gain, what the error: (x-g*y)^2 
242                   (result is negated and we discard x^2 because it's constant) */
243                score = 2.f*g*tmp_xy - g*g*tmp_yy;
244
245                if (score>nbest[Lupdate-1]->score)
246                {
247                   int k;
248                   int id = Lupdate-1;
249                   struct NBest *tmp_best;
250
251                   /* Save some pointers that would be deleted and use them for the current entry*/
252                   tmp_best = nbest[Lupdate-1];
253                   while (id > 0 && score > nbest[id-1]->score)
254                      id--;
255                
256                   for (k=Lupdate-1;k>id;k--)
257                      nbest[k] = nbest[k-1];
258
259                   nbest[id] = tmp_best;
260                   nbest[id]->score = score;
261                   nbest[id]->pos = j;
262                   nbest[id]->orig = m;
263                   nbest[id]->sign = sign;
264                   nbest[id]->gain = g;
265                   nbest[id]->xy = tmp_xy;
266                   nbest[id]->yy = tmp_yy;
267                   nbest[id]->yp = tmp_yp;
268                }
269             }
270          }
271
272       }
273       
274       if (!(nbest[0]->score > -1e10f))
275          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
276       /* Only now that we've made the final choice, update ny/iny and others */
277       for (k=0;k<Lupdate;k++)
278       {
279          int n;
280          int is;
281          float s;
282          is = nbest[k]->sign*pulsesAtOnce;
283          s = is;
284          for (n=0;n<N;n++)
285             ny[k][n] = y[nbest[k]->orig][n] - _alpha*s*p[nbest[k]->pos]*p[n];
286          ny[k][nbest[k]->pos] += s;
287
288          for (n=0;n<N;n++)
289             iny[k][n] = iy[nbest[k]->orig][n];
290          iny[k][nbest[k]->pos] += is;
291
292          xy[k] = nbest[k]->xy;
293          yy[k] = nbest[k]->yy;
294          yp[k] = nbest[k]->yp;
295       }
296       /* Swap ny/iny with y/iy */
297       for (k=0;k<Lupdate;k++)
298       {
299          float *tmp_ny;
300          int *tmp_iny;
301
302          tmp_ny = ny[k];
303          ny[k] = y[k];
304          y[k] = tmp_ny;
305          tmp_iny = iny[k];
306          iny[k] = iy[k];
307          iy[k] = tmp_iny;
308       }
309       pulsesLeft -= pulsesAtOnce;
310    }
311    
312 #if 0
313    if (0) {
314       float err=0;
315       for (i=0;i<N;i++)
316          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
317       /*if (N<=10)
318         printf ("%f %d %d\n", err, K, N);*/
319    }
320    /* Sanity checks, don't bother */
321    if (0) {
322       for (i=0;i<N;i++)
323          x[i] = p[i]+nbest[0]->gain*y[0][i];
324       float E=1e-15;
325       int ABS = 0;
326       for (i=0;i<N;i++)
327          ABS += abs(iy[0][i]);
328       /*if (K != ABS)
329          printf ("%d %d\n", K, ABS);*/
330       for (i=0;i<N;i++)
331          E += x[i]*x[i];
332       /*printf ("%f\n", E);*/
333       E = 1/sqrt(E);
334       for (i=0;i<N;i++)
335          x[i] *= E;
336    }
337 #endif
338    
339    encode_pulses(iy[0], N, K, enc);
340    
341    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
342       due to the recursive computation used in quantisation.
343       Not quite sure whether we need that or not */
344    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
345 }
346
347 /** Decode pulse vector and combine the result with the pitch vector to produce
348     the final normalised signal in the current band. */
349 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
350 {
351    VARDECL(int *iy);
352    ALLOC(iy, N, int);
353    decode_pulses(iy, N, K, dec);
354    mix_pitch_and_residual(iy, X, N, K, P, alpha);
355 }
356
357
358 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
359
360 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
361 {
362    int i,j;
363    int best=0;
364    float best_score=0;
365    float s = 1;
366    int sign;
367    float E;
368    float pred_gain;
369    int max_pos = N0-N/B;
370    if (max_pos > 32)
371       max_pos = 32;
372
373    for (i=0;i<max_pos*B;i+=B)
374    {
375       int j;
376       float xy=0, yy=0;
377       float score;
378       for (j=0;j<N;j++)
379       {
380          xy += 1.f*x[j]*Y[i+N-j-1];
381          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
382       }
383       score = xy*xy/(.001+yy);
384       if (score > best_score)
385       {
386          best_score = score;
387          best = i;
388          if (xy>0)
389             s = 1;
390          else
391             s = -1;
392       }
393    }
394    if (s<0)
395       sign = 1;
396    else
397       sign = 0;
398    /*printf ("%d %d ", sign, best);*/
399    ec_enc_uint(enc,sign,2);
400    ec_enc_uint(enc,best/B,max_pos);
401    /*printf ("%d %f\n", best, best_score);*/
402    
403    if (K>10)
404       pred_gain = pg[10];
405    else
406       pred_gain = pg[K];
407    E = 1e-10;
408    for (j=0;j<N;j++)
409    {
410       P[j] = s*Y[best+N-j-1];
411       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
412    }
413    E = pred_gain/sqrt(E);
414    for (j=0;j<N;j++)
415       P[j] *= E;
416    if (K>0)
417    {
418       for (j=0;j<N;j++)
419          x[j] -= P[j];
420    } else {
421       for (j=0;j<N;j++)
422          x[j] = P[j];
423    }
424    /*printf ("quant ");*/
425    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
426
427 }
428
429 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
430 {
431    int j;
432    int sign;
433    float s;
434    int best;
435    float E;
436    float pred_gain;
437    int max_pos = N0-N/B;
438    if (max_pos > 32)
439       max_pos = 32;
440    
441    sign = ec_dec_uint(dec, 2);
442    if (sign == 0)
443       s = 1;
444    else
445       s = -1;
446    
447    best = B*ec_dec_uint(dec, max_pos);
448    /*printf ("%d %d ", sign, best);*/
449
450    if (K>10)
451       pred_gain = pg[10];
452    else
453       pred_gain = pg[K];
454    E = 1e-10;
455    for (j=0;j<N;j++)
456    {
457       P[j] = s*Y[best+N-j-1];
458       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
459    }
460    E = pred_gain/sqrt(E);
461    for (j=0;j<N;j++)
462       P[j] *= E;
463    if (K==0)
464    {
465       for (j=0;j<N;j++)
466          x[j] = P[j];
467    }
468 }
469
470 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
471 {
472    int i, j;
473    float E;
474    
475    E = 1e-10;
476    if (N0 >= Nmax/2)
477    {
478       for (i=0;i<B;i++)
479       {
480          for (j=0;j<N/B;j++)
481          {
482             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
483             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
484          }
485       }
486    } else {
487       for (j=0;j<N;j++)
488       {
489          P[j] = Y[j];
490          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
491       }
492    }
493    E = 1.f/sqrt(E);
494    for (j=0;j<N;j++)
495       P[j] *= E;
496    for (j=0;j<N;j++)
497       x[j] = P[j];
498 }
499