fixed-point: mix_pitch_and_residual() check-point #3
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42
43 /* Enable this or define your own implementation if you want to speed up the
44    VQ search (used in inner loop only) */
45 #if 0
46 #include <xmmintrin.h>
47 static inline float approx_sqrt(float x)
48 {
49    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
50    return x;
51 }
52 static inline float approx_inv(float x)
53 {
54    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
55    return x;
56 }
57 #else
58 #define approx_sqrt(x) (sqrt(x))
59 #define approx_inv(x) (1.f/(x))
60 #endif
61
62 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
63    applies the compression in the pitch direction, computes the gain that will
64    give ||p+g*y||=1 and mixes the residual with the pitch. */
65 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
66 {
67    int i;
68    celt_word32_t Ryp, Ryy, Rpp;
69    float g;
70    VARDECL(celt_norm_t *y);
71 #ifdef FIXED_POINT
72    int yshift = 14-EC_ILOG(K);
73 #endif
74    ALLOC(y, N, celt_norm_t);
75
76    /*for (i=0;i<N;i++)
77    printf ("%d ", iy[i]);*/
78    Rpp = 0;
79    for (i=0;i<N;i++)
80       Rpp = MAC16_16(Rpp,P[i],P[i]);
81
82    Ryp = 0;
83    for (i=0;i<N;i++)
84       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
85
86    /* Remove part of the pitch component to compute the real residual from
87       the encoded (int) one */
88    for (i=0;i<N;i++)
89       y[i] = SUB16(SHL16(iy[i],yshift),
90                    MULT16_16_Q15(alpha,MULT16_16_Q14(EXTRACT16(SHR32(Ryp,14)),P[i])));
91
92    /* Recompute after the projection (I think it's right) */
93    Ryp = 0;
94    for (i=0;i<N;i++)
95       Ryp = MAC16_16(Ryp,y[i],P[i]);
96
97    Ryy = 0;
98    for (i=0;i<N;i++)
99       Ryy = MAC16_16(Ryy, y[i],y[i]);
100
101    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
102    g = (sqrt(MULT16_16(PSHR32(Ryp,14),PSHR32(Ryp,14)) + Ryy - MULT16_16(PSHR32(Ryy,14),PSHR32(Rpp,14))) - PSHR32(Ryp,14))/Ryy;
103
104    for (i=0;i<N;i++)
105       X[i] = P[i] + NORM_SCALING*g*y[i];
106 }
107
108 /** All the info necessary to keep track of a hypothesis during the search */
109 struct NBest {
110    float score;
111    float gain;
112    int sign;
113    int pos;
114    int orig;
115    float xy;
116    float yy;
117    float yp;
118 };
119
120 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
121 {
122    int L = 3;
123    VARDECL(float *x);
124    VARDECL(float *p);
125    VARDECL(float *_y);
126    VARDECL(float *_ny);
127    VARDECL(int *_iy);
128    VARDECL(int *_iny);
129    VARDECL(float **y);
130    VARDECL(float **ny);
131    VARDECL(int **iy);
132    VARDECL(int **iny);
133    int i, j, k, m;
134    int pulsesLeft;
135    VARDECL(float *xy);
136    VARDECL(float *yy);
137    VARDECL(float *yp);
138    VARDECL(struct NBest *_nbest);
139    VARDECL(struct NBest **nbest);
140    float Rpp=0, Rxp=0;
141    int maxL = 1;
142    float _alpha = Q15_ONE_1*alpha;
143
144    ALLOC(x, N, float);
145    ALLOC(p, N, float);
146    ALLOC(_y, L*N, float);
147    ALLOC(_ny, L*N, float);
148    ALLOC(_iy, L*N, int);
149    ALLOC(_iny, L*N, int);
150    ALLOC(y, L*N, float*);
151    ALLOC(ny, L*N, float*);
152    ALLOC(iy, L*N, int*);
153    ALLOC(iny, L*N, int*);
154    
155    ALLOC(xy, L, float);
156    ALLOC(yy, L, float);
157    ALLOC(yp, L, float);
158    ALLOC(_nbest, L, struct NBest);
159    ALLOC(nbest, L, struct NBest *);
160
161    for (j=0;j<N;j++)
162    {
163       x[j] = X[j]*NORM_SCALING_1;
164       p[j] = P[j]*NORM_SCALING_1;
165    }
166    
167    for (m=0;m<L;m++)
168       nbest[m] = &_nbest[m];
169    
170    for (m=0;m<L;m++)
171    {
172       ny[m] = &_ny[m*N];
173       iny[m] = &_iny[m*N];
174       y[m] = &_y[m*N];
175       iy[m] = &_iy[m*N];
176    }
177    
178    for (j=0;j<N;j++)
179    {
180       Rpp += p[j]*p[j];
181       Rxp += x[j]*p[j];
182    }
183    if (Rpp>1)
184       celt_fatal("Rpp > 1");
185
186    /* We only need to initialise the zero because the first iteration only uses that */
187    for (i=0;i<N;i++)
188       y[0][i] = 0;
189    for (i=0;i<N;i++)
190       iy[0][i] = 0;
191    xy[0] = yy[0] = yp[0] = 0;
192
193    pulsesLeft = K;
194    while (pulsesLeft > 0)
195    {
196       int pulsesAtOnce=1;
197       int Lupdate = L;
198       int L2 = L;
199       
200       /* Decide on complexity strategy */
201       pulsesAtOnce = pulsesLeft/N;
202       if (pulsesAtOnce<1)
203          pulsesAtOnce = 1;
204       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
205          Lupdate = 1;
206       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
207       L2 = Lupdate;
208       if (L2>maxL)
209       {
210          L2 = maxL;
211          maxL *= N;
212       }
213
214       for (m=0;m<Lupdate;m++)
215          nbest[m]->score = -1e10f;
216
217       for (m=0;m<L2;m++)
218       {
219          for (j=0;j<N;j++)
220          {
221             int sign;
222             /*if (x[j]>0) sign=1; else sign=-1;*/
223             for (sign=-1;sign<=1;sign+=2)
224             {
225                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
226                float tmp_xy, tmp_yy, tmp_yp;
227                float score;
228                float g;
229                float s = sign*pulsesAtOnce;
230                
231                /* All pulses at one location must have the same sign. */
232                if (iy[m][j]*sign < 0)
233                   continue;
234
235                /* Updating the sums of the new pulse(s) */
236                tmp_xy = xy[m] + s*x[j]               - _alpha*s*p[j]*Rxp;
237                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*_alpha*_alpha*p[j]*p[j]*Rpp - 2.f*_alpha*s*p[j]*yp[m] - 2.f*s*s*_alpha*p[j]*p[j];
238                tmp_yp = yp[m] + s*p[j]               *(1.f-_alpha*Rpp);
239                
240                /* Compute the gain such that ||p + g*y|| = 1 */
241                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
242                /* Knowing that gain, what the error: (x-g*y)^2 
243                   (result is negated and we discard x^2 because it's constant) */
244                score = 2.f*g*tmp_xy - g*g*tmp_yy;
245
246                if (score>nbest[Lupdate-1]->score)
247                {
248                   int k;
249                   int id = Lupdate-1;
250                   struct NBest *tmp_best;
251
252                   /* Save some pointers that would be deleted and use them for the current entry*/
253                   tmp_best = nbest[Lupdate-1];
254                   while (id > 0 && score > nbest[id-1]->score)
255                      id--;
256                
257                   for (k=Lupdate-1;k>id;k--)
258                      nbest[k] = nbest[k-1];
259
260                   nbest[id] = tmp_best;
261                   nbest[id]->score = score;
262                   nbest[id]->pos = j;
263                   nbest[id]->orig = m;
264                   nbest[id]->sign = sign;
265                   nbest[id]->gain = g;
266                   nbest[id]->xy = tmp_xy;
267                   nbest[id]->yy = tmp_yy;
268                   nbest[id]->yp = tmp_yp;
269                }
270             }
271          }
272
273       }
274       
275       if (!(nbest[0]->score > -1e10f))
276          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
277       /* Only now that we've made the final choice, update ny/iny and others */
278       for (k=0;k<Lupdate;k++)
279       {
280          int n;
281          int is;
282          float s;
283          is = nbest[k]->sign*pulsesAtOnce;
284          s = is;
285          for (n=0;n<N;n++)
286             ny[k][n] = y[nbest[k]->orig][n] - _alpha*s*p[nbest[k]->pos]*p[n];
287          ny[k][nbest[k]->pos] += s;
288
289          for (n=0;n<N;n++)
290             iny[k][n] = iy[nbest[k]->orig][n];
291          iny[k][nbest[k]->pos] += is;
292
293          xy[k] = nbest[k]->xy;
294          yy[k] = nbest[k]->yy;
295          yp[k] = nbest[k]->yp;
296       }
297       /* Swap ny/iny with y/iy */
298       for (k=0;k<Lupdate;k++)
299       {
300          float *tmp_ny;
301          int *tmp_iny;
302
303          tmp_ny = ny[k];
304          ny[k] = y[k];
305          y[k] = tmp_ny;
306          tmp_iny = iny[k];
307          iny[k] = iy[k];
308          iy[k] = tmp_iny;
309       }
310       pulsesLeft -= pulsesAtOnce;
311    }
312    
313 #if 0
314    if (0) {
315       float err=0;
316       for (i=0;i<N;i++)
317          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
318       /*if (N<=10)
319         printf ("%f %d %d\n", err, K, N);*/
320    }
321    /* Sanity checks, don't bother */
322    if (0) {
323       for (i=0;i<N;i++)
324          x[i] = p[i]+nbest[0]->gain*y[0][i];
325       float E=1e-15;
326       int ABS = 0;
327       for (i=0;i<N;i++)
328          ABS += abs(iy[0][i]);
329       /*if (K != ABS)
330          printf ("%d %d\n", K, ABS);*/
331       for (i=0;i<N;i++)
332          E += x[i]*x[i];
333       /*printf ("%f\n", E);*/
334       E = 1/sqrt(E);
335       for (i=0;i<N;i++)
336          x[i] *= E;
337    }
338 #endif
339    
340    encode_pulses(iy[0], N, K, enc);
341    
342    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
343       due to the recursive computation used in quantisation.
344       Not quite sure whether we need that or not */
345    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
346 }
347
348 /** Decode pulse vector and combine the result with the pitch vector to produce
349     the final normalised signal in the current band. */
350 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
351 {
352    VARDECL(int *iy);
353    ALLOC(iy, N, int);
354    decode_pulses(iy, N, K, dec);
355    mix_pitch_and_residual(iy, X, N, K, P, alpha);
356 }
357
358
359 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
360
361 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
362 {
363    int i,j;
364    int best=0;
365    float best_score=0;
366    float s = 1;
367    int sign;
368    float E;
369    float pred_gain;
370    int max_pos = N0-N/B;
371    if (max_pos > 32)
372       max_pos = 32;
373
374    for (i=0;i<max_pos*B;i+=B)
375    {
376       int j;
377       float xy=0, yy=0;
378       float score;
379       for (j=0;j<N;j++)
380       {
381          xy += 1.f*x[j]*Y[i+N-j-1];
382          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
383       }
384       score = xy*xy/(.001+yy);
385       if (score > best_score)
386       {
387          best_score = score;
388          best = i;
389          if (xy>0)
390             s = 1;
391          else
392             s = -1;
393       }
394    }
395    if (s<0)
396       sign = 1;
397    else
398       sign = 0;
399    /*printf ("%d %d ", sign, best);*/
400    ec_enc_uint(enc,sign,2);
401    ec_enc_uint(enc,best/B,max_pos);
402    /*printf ("%d %f\n", best, best_score);*/
403    
404    if (K>10)
405       pred_gain = pg[10];
406    else
407       pred_gain = pg[K];
408    E = 1e-10;
409    for (j=0;j<N;j++)
410    {
411       P[j] = s*Y[best+N-j-1];
412       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
413    }
414    E = pred_gain/sqrt(E);
415    for (j=0;j<N;j++)
416       P[j] *= E;
417    if (K>0)
418    {
419       for (j=0;j<N;j++)
420          x[j] -= P[j];
421    } else {
422       for (j=0;j<N;j++)
423          x[j] = P[j];
424    }
425    /*printf ("quant ");*/
426    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
427
428 }
429
430 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
431 {
432    int j;
433    int sign;
434    float s;
435    int best;
436    float E;
437    float pred_gain;
438    int max_pos = N0-N/B;
439    if (max_pos > 32)
440       max_pos = 32;
441    
442    sign = ec_dec_uint(dec, 2);
443    if (sign == 0)
444       s = 1;
445    else
446       s = -1;
447    
448    best = B*ec_dec_uint(dec, max_pos);
449    /*printf ("%d %d ", sign, best);*/
450
451    if (K>10)
452       pred_gain = pg[10];
453    else
454       pred_gain = pg[K];
455    E = 1e-10;
456    for (j=0;j<N;j++)
457    {
458       P[j] = s*Y[best+N-j-1];
459       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
460    }
461    E = pred_gain/sqrt(E);
462    for (j=0;j<N;j++)
463       P[j] *= E;
464    if (K==0)
465    {
466       for (j=0;j<N;j++)
467          x[j] = P[j];
468    }
469 }
470
471 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
472 {
473    int i, j;
474    float E;
475    
476    E = 1e-10;
477    if (N0 >= Nmax/2)
478    {
479       for (i=0;i<B;i++)
480       {
481          for (j=0;j<N/B;j++)
482          {
483             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
484             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
485          }
486       }
487    } else {
488       for (j=0;j<N;j++)
489       {
490          P[j] = Y[j];
491          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
492       }
493    }
494    E = 1.f/sqrt(E);
495    for (j=0;j<N;j++)
496       P[j] *= E;
497    for (j=0;j<N;j++)
498       x[j] = P[j];
499 }
500