fixed-point: second check-point in the conversion of alg_quant()
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72 #ifdef FIXED_POINT
73    int yshift = 14-EC_ILOG(K);
74 #endif
75    ALLOC(y, N, celt_norm_t);
76
77    /*for (i=0;i<N;i++)
78    printf ("%d ", iy[i]);*/
79    Rpp = 0;
80    for (i=0;i<N;i++)
81       Rpp = MAC16_16(Rpp,P[i],P[i]);
82
83    Ryp = 0;
84    for (i=0;i<N;i++)
85       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
86
87    /* Remove part of the pitch component to compute the real residual from
88       the encoded (int) one */
89    for (i=0;i<N;i++)
90       y[i] = SUB16(SHL16(iy[i],yshift),
91                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
92
93    /* Recompute after the projection (I think it's right) */
94    Ryp = 0;
95    for (i=0;i<N;i++)
96       Ryp = MAC16_16(Ryp,y[i],P[i]);
97
98    Ryy = 0;
99    for (i=0;i<N;i++)
100       Ryy = MAC16_16(Ryy, y[i],y[i]);
101
102    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
103    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
104
105    for (i=0;i<N;i++)
106       X[i] = P[i] + MULT16_32_Q14(y[i], g);
107 }
108
109 /** All the info necessary to keep track of a hypothesis during the search */
110 struct NBest {
111    float score;
112    float gain;
113    int sign;
114    int pos;
115    int orig;
116    float xy;
117    float yy;
118    float yp;
119 };
120
121 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
122 {
123    int L = 3;
124    VARDECL(float *p);
125    VARDECL(float *_y);
126    VARDECL(float *_ny);
127    VARDECL(int *_iy);
128    VARDECL(int *_iny);
129    VARDECL(float **y);
130    VARDECL(float **ny);
131    VARDECL(int **iy);
132    VARDECL(int **iny);
133    int i, j, k, m;
134    int pulsesLeft;
135    VARDECL(celt_word32_t *xy);
136    VARDECL(celt_word32_t *yy);
137    VARDECL(celt_word32_t *yp);
138    VARDECL(struct NBest *_nbest);
139    VARDECL(struct NBest **nbest);
140    celt_word32_t Rpp=0, Rxp=0;
141    int maxL = 1;
142    float _alpha = Q15_ONE_1*alpha;
143 #ifdef FIXED_POINT
144    int yshift = 14-EC_ILOG(K);
145 #endif
146
147    ALLOC(p, N, float);
148    ALLOC(_y, L*N, float);
149    ALLOC(_ny, L*N, float);
150    ALLOC(_iy, L*N, int);
151    ALLOC(_iny, L*N, int);
152    ALLOC(y, L*N, float*);
153    ALLOC(ny, L*N, float*);
154    ALLOC(iy, L*N, int*);
155    ALLOC(iny, L*N, int*);
156    
157    ALLOC(xy, L, celt_word32_t);
158    ALLOC(yy, L, celt_word32_t);
159    ALLOC(yp, L, celt_word32_t);
160    ALLOC(_nbest, L, struct NBest);
161    ALLOC(nbest, L, struct NBest *);
162
163    for (j=0;j<N;j++)
164    {
165       p[j] = P[j]*NORM_SCALING_1;
166    }
167    
168    for (m=0;m<L;m++)
169       nbest[m] = &_nbest[m];
170    
171    for (m=0;m<L;m++)
172    {
173       ny[m] = &_ny[m*N];
174       iny[m] = &_iny[m*N];
175       y[m] = &_y[m*N];
176       iy[m] = &_iy[m*N];
177    }
178    
179    for (j=0;j<N;j++)
180    {
181       Rpp = MAC16_16(Rpp, P[j],P[j]);
182       Rxp = MAC16_16(Rxp, X[j],P[j]);
183    }
184    Rpp = ROUND(Rpp, NORM_SHIFT);
185    Rxp = ROUND(Rxp, NORM_SHIFT);
186    if (Rpp>NORM_SCALING)
187       celt_fatal("Rpp > 1");
188
189    /* We only need to initialise the zero because the first iteration only uses that */
190    for (i=0;i<N;i++)
191       y[0][i] = 0;
192    for (i=0;i<N;i++)
193       iy[0][i] = 0;
194    xy[0] = yy[0] = yp[0] = 0;
195
196    pulsesLeft = K;
197    while (pulsesLeft > 0)
198    {
199       int pulsesAtOnce=1;
200       int Lupdate = L;
201       int L2 = L;
202       
203       /* Decide on complexity strategy */
204       pulsesAtOnce = pulsesLeft/N;
205       if (pulsesAtOnce<1)
206          pulsesAtOnce = 1;
207       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
208          Lupdate = 1;
209       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
210       L2 = Lupdate;
211       if (L2>maxL)
212       {
213          L2 = maxL;
214          maxL *= N;
215       }
216
217       for (m=0;m<Lupdate;m++)
218          nbest[m]->score = -1e10f;
219
220       for (m=0;m<L2;m++)
221       {
222          for (j=0;j<N;j++)
223          {
224             int sign;
225             /*if (x[j]>0) sign=1; else sign=-1;*/
226             for (sign=-1;sign<=1;sign+=2)
227             {
228                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
229                celt_word32_t tmp_xy, tmp_yy, tmp_yp;
230                float score;
231                float g;
232                float s = SHL16(sign*pulsesAtOnce, yshift);
233                
234                /* All pulses at one location must have the same sign. */
235                if (iy[m][j]*sign < 0)
236                   continue;
237
238                /* Updating the sums of the new pulse(s) */
239                tmp_xy = xy[m] + s*X[j]               - _alpha*s*P[j]*Rxp*NORM_SCALING_1;
240                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*_alpha*_alpha*p[j]*p[j]*Rpp*NORM_SCALING_1 - 2.f*_alpha*s*p[j]*yp[m]*NORM_SCALING_1 - 2.f*s*s*_alpha*p[j]*p[j];
241                tmp_yp = yp[m] + s*P[j]               *(1.f-_alpha*Rpp*NORM_SCALING_1);
242                
243                /* Compute the gain such that ||p + g*y|| = 1 */
244                g = (approx_sqrt(NORM_SCALING_1*NORM_SCALING_1*tmp_yp*tmp_yp + tmp_yy - NORM_SCALING_1*tmp_yy*Rpp) - tmp_yp*NORM_SCALING_1)*approx_inv(tmp_yy);
245                /* Knowing that gain, what the error: (x-g*y)^2 
246                   (result is negated and we discard x^2 because it's constant) */
247                score = 2.f*g*tmp_xy*NORM_SCALING_1 - g*g*tmp_yy;
248
249                if (score>nbest[Lupdate-1]->score)
250                {
251                   int k;
252                   int id = Lupdate-1;
253                   struct NBest *tmp_best;
254
255                   /* Save some pointers that would be deleted and use them for the current entry*/
256                   tmp_best = nbest[Lupdate-1];
257                   while (id > 0 && score > nbest[id-1]->score)
258                      id--;
259                
260                   for (k=Lupdate-1;k>id;k--)
261                      nbest[k] = nbest[k-1];
262
263                   nbest[id] = tmp_best;
264                   nbest[id]->score = score;
265                   nbest[id]->pos = j;
266                   nbest[id]->orig = m;
267                   nbest[id]->sign = sign;
268                   nbest[id]->gain = g;
269                   nbest[id]->xy = tmp_xy;
270                   nbest[id]->yy = tmp_yy;
271                   nbest[id]->yp = tmp_yp;
272                }
273             }
274          }
275
276       }
277       
278       if (!(nbest[0]->score > -1e10f))
279          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
280       /* Only now that we've made the final choice, update ny/iny and others */
281       for (k=0;k<Lupdate;k++)
282       {
283          int n;
284          int is;
285          float s;
286          is = nbest[k]->sign*pulsesAtOnce;
287          s = SHL16(is, yshift);
288          for (n=0;n<N;n++)
289             ny[k][n] = y[nbest[k]->orig][n] - _alpha*s*p[nbest[k]->pos]*p[n];
290          ny[k][nbest[k]->pos] += s;
291
292          for (n=0;n<N;n++)
293             iny[k][n] = iy[nbest[k]->orig][n];
294          iny[k][nbest[k]->pos] += is;
295
296          xy[k] = nbest[k]->xy;
297          yy[k] = nbest[k]->yy;
298          yp[k] = nbest[k]->yp;
299       }
300       /* Swap ny/iny with y/iy */
301       for (k=0;k<Lupdate;k++)
302       {
303          float *tmp_ny;
304          int *tmp_iny;
305
306          tmp_ny = ny[k];
307          ny[k] = y[k];
308          y[k] = tmp_ny;
309          tmp_iny = iny[k];
310          iny[k] = iy[k];
311          iy[k] = tmp_iny;
312       }
313       pulsesLeft -= pulsesAtOnce;
314    }
315    
316 #if 0
317    if (0) {
318       float err=0;
319       for (i=0;i<N;i++)
320          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
321       /*if (N<=10)
322         printf ("%f %d %d\n", err, K, N);*/
323    }
324    /* Sanity checks, don't bother */
325    if (0) {
326       for (i=0;i<N;i++)
327          x[i] = p[i]+nbest[0]->gain*y[0][i];
328       float E=1e-15;
329       int ABS = 0;
330       for (i=0;i<N;i++)
331          ABS += abs(iy[0][i]);
332       /*if (K != ABS)
333          printf ("%d %d\n", K, ABS);*/
334       for (i=0;i<N;i++)
335          E += x[i]*x[i];
336       /*printf ("%f\n", E);*/
337       E = 1/sqrt(E);
338       for (i=0;i<N;i++)
339          x[i] *= E;
340    }
341 #endif
342    
343    encode_pulses(iy[0], N, K, enc);
344    
345    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
346       due to the recursive computation used in quantisation.
347       Not quite sure whether we need that or not */
348    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
349 }
350
351 /** Decode pulse vector and combine the result with the pitch vector to produce
352     the final normalised signal in the current band. */
353 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
354 {
355    VARDECL(int *iy);
356    ALLOC(iy, N, int);
357    decode_pulses(iy, N, K, dec);
358    mix_pitch_and_residual(iy, X, N, K, P, alpha);
359 }
360
361
362 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
363
364 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
365 {
366    int i,j;
367    int best=0;
368    float best_score=0;
369    float s = 1;
370    int sign;
371    float E;
372    float pred_gain;
373    int max_pos = N0-N/B;
374    if (max_pos > 32)
375       max_pos = 32;
376
377    for (i=0;i<max_pos*B;i+=B)
378    {
379       int j;
380       float xy=0, yy=0;
381       float score;
382       for (j=0;j<N;j++)
383       {
384          xy += 1.f*x[j]*Y[i+N-j-1];
385          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
386       }
387       score = xy*xy/(.001+yy);
388       if (score > best_score)
389       {
390          best_score = score;
391          best = i;
392          if (xy>0)
393             s = 1;
394          else
395             s = -1;
396       }
397    }
398    if (s<0)
399       sign = 1;
400    else
401       sign = 0;
402    /*printf ("%d %d ", sign, best);*/
403    ec_enc_uint(enc,sign,2);
404    ec_enc_uint(enc,best/B,max_pos);
405    /*printf ("%d %f\n", best, best_score);*/
406    
407    if (K>10)
408       pred_gain = pg[10];
409    else
410       pred_gain = pg[K];
411    E = 1e-10;
412    for (j=0;j<N;j++)
413    {
414       P[j] = s*Y[best+N-j-1];
415       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
416    }
417    E = pred_gain/sqrt(E);
418    for (j=0;j<N;j++)
419       P[j] *= E;
420    if (K>0)
421    {
422       for (j=0;j<N;j++)
423          x[j] -= P[j];
424    } else {
425       for (j=0;j<N;j++)
426          x[j] = P[j];
427    }
428    /*printf ("quant ");*/
429    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
430
431 }
432
433 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
434 {
435    int j;
436    int sign;
437    float s;
438    int best;
439    float E;
440    float pred_gain;
441    int max_pos = N0-N/B;
442    if (max_pos > 32)
443       max_pos = 32;
444    
445    sign = ec_dec_uint(dec, 2);
446    if (sign == 0)
447       s = 1;
448    else
449       s = -1;
450    
451    best = B*ec_dec_uint(dec, max_pos);
452    /*printf ("%d %d ", sign, best);*/
453
454    if (K>10)
455       pred_gain = pg[10];
456    else
457       pred_gain = pg[K];
458    E = 1e-10;
459    for (j=0;j<N;j++)
460    {
461       P[j] = s*Y[best+N-j-1];
462       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
463    }
464    E = pred_gain/sqrt(E);
465    for (j=0;j<N;j++)
466       P[j] *= E;
467    if (K==0)
468    {
469       for (j=0;j<N;j++)
470          x[j] = P[j];
471    }
472 }
473
474 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
475 {
476    int i, j;
477    float E;
478    
479    E = 1e-10;
480    if (N0 >= Nmax/2)
481    {
482       for (i=0;i<B;i++)
483       {
484          for (j=0;j<N/B;j++)
485          {
486             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
487             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
488          }
489       }
490    } else {
491       for (j=0;j<N;j++)
492       {
493          P[j] = Y[j];
494          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
495       }
496    }
497    E = 1.f/sqrt(E);
498    for (j=0;j<N;j++)
499       P[j] *= E;
500    for (j=0;j<N;j++)
501       x[j] = P[j];
502 }
503