fixed-point: half-way converting mix_pitch_and_residual() -- just check-pointing
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42
43 /* Enable this or define your own implementation if you want to speed up the
44    VQ search (used in inner loop only) */
45 #if 0
46 #include <xmmintrin.h>
47 static inline float approx_sqrt(float x)
48 {
49    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
50    return x;
51 }
52 static inline float approx_inv(float x)
53 {
54    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
55    return x;
56 }
57 #else
58 #define approx_sqrt(x) (sqrt(x))
59 #define approx_inv(x) (1.f/(x))
60 #endif
61
62 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
63    applies the compression in the pitch direction, computes the gain that will
64    give ||p+g*y||=1 and mixes the residual with the pitch. */
65 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
66 {
67    int i;
68    float Rpp=0, Ryp=0, Ryy=0;
69    celt_word32_t Ryp2;
70    float g;
71    VARDECL(celt_norm_t *y);
72    VARDECL(float *x);
73    VARDECL(float *p);
74    float _alpha = Q15_ONE_1*alpha;
75 #ifdef FIXED_POINT
76    int yshift = 15-EC_ILOG(K);
77 #endif
78    ALLOC(y, N, celt_norm_t);
79    ALLOC(x, N, float);
80    ALLOC(p, N, float);
81
82    for (i=0;i<N;i++)
83       p[i] = P[i]*NORM_SCALING_1;
84
85    /*for (i=0;i<N;i++)
86    printf ("%d ", iy[i]);*/
87    for (i=0;i<N;i++)
88       Rpp += p[i]*p[i];
89
90    Ryp2 = 0;
91    for (i=0;i<N;i++)
92       Ryp2 += MULT16_16(SHL16(iy[i],yshift),P[i]);
93
94    /* Remove part of the pitch component to compute the real residual from
95       the encoded (int) one */
96    for (i=0;i<N;i++)
97       y[i] = SUB16(SHL16(iy[i],yshift),
98                    MULT16_16_Q15(alpha,MULT16_16_Q14(EXTRACT16(SHR32(Ryp2,14)),P[i])));
99
100    /* Recompute after the projection (I think it's right) */
101    Ryp = 0;
102    for (i=0;i<N;i++)
103       Ryp += y[i]*p[i];
104
105    for (i=0;i<N;i++)
106       Ryy += y[i]*y[i];
107
108    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
109
110    for (i=0;i<N;i++)
111       X[i] = P[i] + NORM_SCALING*g*y[i];
112 }
113
114 /** All the info necessary to keep track of a hypothesis during the search */
115 struct NBest {
116    float score;
117    float gain;
118    int sign;
119    int pos;
120    int orig;
121    float xy;
122    float yy;
123    float yp;
124 };
125
126 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
127 {
128    int L = 3;
129    VARDECL(float *x);
130    VARDECL(float *p);
131    VARDECL(float *_y);
132    VARDECL(float *_ny);
133    VARDECL(int *_iy);
134    VARDECL(int *_iny);
135    VARDECL(float **y);
136    VARDECL(float **ny);
137    VARDECL(int **iy);
138    VARDECL(int **iny);
139    int i, j, k, m;
140    int pulsesLeft;
141    VARDECL(float *xy);
142    VARDECL(float *yy);
143    VARDECL(float *yp);
144    VARDECL(struct NBest *_nbest);
145    VARDECL(struct NBest **nbest);
146    float Rpp=0, Rxp=0;
147    int maxL = 1;
148    float _alpha = Q15_ONE_1*alpha;
149
150    ALLOC(x, N, float);
151    ALLOC(p, N, float);
152    ALLOC(_y, L*N, float);
153    ALLOC(_ny, L*N, float);
154    ALLOC(_iy, L*N, int);
155    ALLOC(_iny, L*N, int);
156    ALLOC(y, L*N, float*);
157    ALLOC(ny, L*N, float*);
158    ALLOC(iy, L*N, int*);
159    ALLOC(iny, L*N, int*);
160    
161    ALLOC(xy, L, float);
162    ALLOC(yy, L, float);
163    ALLOC(yp, L, float);
164    ALLOC(_nbest, L, struct NBest);
165    ALLOC(nbest, L, struct NBest *);
166
167    for (j=0;j<N;j++)
168    {
169       x[j] = X[j]*NORM_SCALING_1;
170       p[j] = P[j]*NORM_SCALING_1;
171    }
172    
173    for (m=0;m<L;m++)
174       nbest[m] = &_nbest[m];
175    
176    for (m=0;m<L;m++)
177    {
178       ny[m] = &_ny[m*N];
179       iny[m] = &_iny[m*N];
180       y[m] = &_y[m*N];
181       iy[m] = &_iy[m*N];
182    }
183    
184    for (j=0;j<N;j++)
185    {
186       Rpp += p[j]*p[j];
187       Rxp += x[j]*p[j];
188    }
189    if (Rpp>1)
190       celt_fatal("Rpp > 1");
191
192    /* We only need to initialise the zero because the first iteration only uses that */
193    for (i=0;i<N;i++)
194       y[0][i] = 0;
195    for (i=0;i<N;i++)
196       iy[0][i] = 0;
197    xy[0] = yy[0] = yp[0] = 0;
198
199    pulsesLeft = K;
200    while (pulsesLeft > 0)
201    {
202       int pulsesAtOnce=1;
203       int Lupdate = L;
204       int L2 = L;
205       
206       /* Decide on complexity strategy */
207       pulsesAtOnce = pulsesLeft/N;
208       if (pulsesAtOnce<1)
209          pulsesAtOnce = 1;
210       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
211          Lupdate = 1;
212       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
213       L2 = Lupdate;
214       if (L2>maxL)
215       {
216          L2 = maxL;
217          maxL *= N;
218       }
219
220       for (m=0;m<Lupdate;m++)
221          nbest[m]->score = -1e10f;
222
223       for (m=0;m<L2;m++)
224       {
225          for (j=0;j<N;j++)
226          {
227             int sign;
228             /*if (x[j]>0) sign=1; else sign=-1;*/
229             for (sign=-1;sign<=1;sign+=2)
230             {
231                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
232                float tmp_xy, tmp_yy, tmp_yp;
233                float score;
234                float g;
235                float s = sign*pulsesAtOnce;
236                
237                /* All pulses at one location must have the same sign. */
238                if (iy[m][j]*sign < 0)
239                   continue;
240
241                /* Updating the sums of the new pulse(s) */
242                tmp_xy = xy[m] + s*x[j]               - _alpha*s*p[j]*Rxp;
243                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*_alpha*_alpha*p[j]*p[j]*Rpp - 2.f*_alpha*s*p[j]*yp[m] - 2.f*s*s*_alpha*p[j]*p[j];
244                tmp_yp = yp[m] + s*p[j]               *(1.f-_alpha*Rpp);
245                
246                /* Compute the gain such that ||p + g*y|| = 1 */
247                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
248                /* Knowing that gain, what the error: (x-g*y)^2 
249                   (result is negated and we discard x^2 because it's constant) */
250                score = 2.f*g*tmp_xy - g*g*tmp_yy;
251
252                if (score>nbest[Lupdate-1]->score)
253                {
254                   int k;
255                   int id = Lupdate-1;
256                   struct NBest *tmp_best;
257
258                   /* Save some pointers that would be deleted and use them for the current entry*/
259                   tmp_best = nbest[Lupdate-1];
260                   while (id > 0 && score > nbest[id-1]->score)
261                      id--;
262                
263                   for (k=Lupdate-1;k>id;k--)
264                      nbest[k] = nbest[k-1];
265
266                   nbest[id] = tmp_best;
267                   nbest[id]->score = score;
268                   nbest[id]->pos = j;
269                   nbest[id]->orig = m;
270                   nbest[id]->sign = sign;
271                   nbest[id]->gain = g;
272                   nbest[id]->xy = tmp_xy;
273                   nbest[id]->yy = tmp_yy;
274                   nbest[id]->yp = tmp_yp;
275                }
276             }
277          }
278
279       }
280       
281       if (!(nbest[0]->score > -1e10f))
282          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
283       /* Only now that we've made the final choice, update ny/iny and others */
284       for (k=0;k<Lupdate;k++)
285       {
286          int n;
287          int is;
288          float s;
289          is = nbest[k]->sign*pulsesAtOnce;
290          s = is;
291          for (n=0;n<N;n++)
292             ny[k][n] = y[nbest[k]->orig][n] - _alpha*s*p[nbest[k]->pos]*p[n];
293          ny[k][nbest[k]->pos] += s;
294
295          for (n=0;n<N;n++)
296             iny[k][n] = iy[nbest[k]->orig][n];
297          iny[k][nbest[k]->pos] += is;
298
299          xy[k] = nbest[k]->xy;
300          yy[k] = nbest[k]->yy;
301          yp[k] = nbest[k]->yp;
302       }
303       /* Swap ny/iny with y/iy */
304       for (k=0;k<Lupdate;k++)
305       {
306          float *tmp_ny;
307          int *tmp_iny;
308
309          tmp_ny = ny[k];
310          ny[k] = y[k];
311          y[k] = tmp_ny;
312          tmp_iny = iny[k];
313          iny[k] = iy[k];
314          iy[k] = tmp_iny;
315       }
316       pulsesLeft -= pulsesAtOnce;
317    }
318    
319 #if 0
320    if (0) {
321       float err=0;
322       for (i=0;i<N;i++)
323          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
324       /*if (N<=10)
325         printf ("%f %d %d\n", err, K, N);*/
326    }
327    /* Sanity checks, don't bother */
328    if (0) {
329       for (i=0;i<N;i++)
330          x[i] = p[i]+nbest[0]->gain*y[0][i];
331       float E=1e-15;
332       int ABS = 0;
333       for (i=0;i<N;i++)
334          ABS += abs(iy[0][i]);
335       /*if (K != ABS)
336          printf ("%d %d\n", K, ABS);*/
337       for (i=0;i<N;i++)
338          E += x[i]*x[i];
339       /*printf ("%f\n", E);*/
340       E = 1/sqrt(E);
341       for (i=0;i<N;i++)
342          x[i] *= E;
343    }
344 #endif
345    
346    encode_pulses(iy[0], N, K, enc);
347    
348    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
349       due to the recursive computation used in quantisation.
350       Not quite sure whether we need that or not */
351    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
352 }
353
354 /** Decode pulse vector and combine the result with the pitch vector to produce
355     the final normalised signal in the current band. */
356 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
357 {
358    VARDECL(int *iy);
359    ALLOC(iy, N, int);
360    decode_pulses(iy, N, K, dec);
361    mix_pitch_and_residual(iy, X, N, K, P, alpha);
362 }
363
364
365 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
366
367 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
368 {
369    int i,j;
370    int best=0;
371    float best_score=0;
372    float s = 1;
373    int sign;
374    float E;
375    float pred_gain;
376    int max_pos = N0-N/B;
377    if (max_pos > 32)
378       max_pos = 32;
379
380    for (i=0;i<max_pos*B;i+=B)
381    {
382       int j;
383       float xy=0, yy=0;
384       float score;
385       for (j=0;j<N;j++)
386       {
387          xy += 1.f*x[j]*Y[i+N-j-1];
388          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
389       }
390       score = xy*xy/(.001+yy);
391       if (score > best_score)
392       {
393          best_score = score;
394          best = i;
395          if (xy>0)
396             s = 1;
397          else
398             s = -1;
399       }
400    }
401    if (s<0)
402       sign = 1;
403    else
404       sign = 0;
405    /*printf ("%d %d ", sign, best);*/
406    ec_enc_uint(enc,sign,2);
407    ec_enc_uint(enc,best/B,max_pos);
408    /*printf ("%d %f\n", best, best_score);*/
409    
410    if (K>10)
411       pred_gain = pg[10];
412    else
413       pred_gain = pg[K];
414    E = 1e-10;
415    for (j=0;j<N;j++)
416    {
417       P[j] = s*Y[best+N-j-1];
418       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
419    }
420    E = pred_gain/sqrt(E);
421    for (j=0;j<N;j++)
422       P[j] *= E;
423    if (K>0)
424    {
425       for (j=0;j<N;j++)
426          x[j] -= P[j];
427    } else {
428       for (j=0;j<N;j++)
429          x[j] = P[j];
430    }
431    /*printf ("quant ");*/
432    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
433
434 }
435
436 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
437 {
438    int j;
439    int sign;
440    float s;
441    int best;
442    float E;
443    float pred_gain;
444    int max_pos = N0-N/B;
445    if (max_pos > 32)
446       max_pos = 32;
447    
448    sign = ec_dec_uint(dec, 2);
449    if (sign == 0)
450       s = 1;
451    else
452       s = -1;
453    
454    best = B*ec_dec_uint(dec, max_pos);
455    /*printf ("%d %d ", sign, best);*/
456
457    if (K>10)
458       pred_gain = pg[10];
459    else
460       pred_gain = pg[K];
461    E = 1e-10;
462    for (j=0;j<N;j++)
463    {
464       P[j] = s*Y[best+N-j-1];
465       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
466    }
467    E = pred_gain/sqrt(E);
468    for (j=0;j<N;j++)
469       P[j] *= E;
470    if (K==0)
471    {
472       for (j=0;j<N;j++)
473          x[j] = P[j];
474    }
475 }
476
477 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
478 {
479    int i, j;
480    float E;
481    
482    E = 1e-10;
483    if (N0 >= Nmax/2)
484    {
485       for (i=0;i<B;i++)
486       {
487          for (j=0;j<N/B;j++)
488          {
489             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
490             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
491          }
492       }
493    } else {
494       for (j=0;j<N;j++)
495       {
496          P[j] = Y[j];
497          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
498       }
499    }
500    E = 1.f/sqrt(E);
501    for (j=0;j<N;j++)
502       P[j] *= E;
503    for (j=0;j<N;j++)
504       x[j] = P[j];
505 }
506