fixed-point: First check-point in alg_quant() conversion
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72 #ifdef FIXED_POINT
73    int yshift = 14-EC_ILOG(K);
74 #endif
75    ALLOC(y, N, celt_norm_t);
76
77    /*for (i=0;i<N;i++)
78    printf ("%d ", iy[i]);*/
79    Rpp = 0;
80    for (i=0;i<N;i++)
81       Rpp = MAC16_16(Rpp,P[i],P[i]);
82
83    Ryp = 0;
84    for (i=0;i<N;i++)
85       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
86
87    /* Remove part of the pitch component to compute the real residual from
88       the encoded (int) one */
89    for (i=0;i<N;i++)
90       y[i] = SUB16(SHL16(iy[i],yshift),
91                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
92
93    /* Recompute after the projection (I think it's right) */
94    Ryp = 0;
95    for (i=0;i<N;i++)
96       Ryp = MAC16_16(Ryp,y[i],P[i]);
97
98    Ryy = 0;
99    for (i=0;i<N;i++)
100       Ryy = MAC16_16(Ryy, y[i],y[i]);
101
102    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
103    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
104
105    for (i=0;i<N;i++)
106       X[i] = P[i] + MULT16_32_Q14(y[i], g);
107 }
108
109 /** All the info necessary to keep track of a hypothesis during the search */
110 struct NBest {
111    float score;
112    float gain;
113    int sign;
114    int pos;
115    int orig;
116    float xy;
117    float yy;
118    float yp;
119 };
120
121 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
122 {
123    int L = 3;
124    VARDECL(float *x);
125    VARDECL(float *p);
126    VARDECL(float *_y);
127    VARDECL(float *_ny);
128    VARDECL(int *_iy);
129    VARDECL(int *_iny);
130    VARDECL(float **y);
131    VARDECL(float **ny);
132    VARDECL(int **iy);
133    VARDECL(int **iny);
134    int i, j, k, m;
135    int pulsesLeft;
136    VARDECL(float *xy);
137    VARDECL(float *yy);
138    VARDECL(float *yp);
139    VARDECL(struct NBest *_nbest);
140    VARDECL(struct NBest **nbest);
141    celt_word32_t Rpp=0, Rxp=0;
142    int maxL = 1;
143    float _alpha = Q15_ONE_1*alpha;
144 #ifdef FIXED_POINT
145    int yshift = 14-EC_ILOG(K);
146 #endif
147
148    ALLOC(x, N, float);
149    ALLOC(p, N, float);
150    ALLOC(_y, L*N, float);
151    ALLOC(_ny, L*N, float);
152    ALLOC(_iy, L*N, int);
153    ALLOC(_iny, L*N, int);
154    ALLOC(y, L*N, float*);
155    ALLOC(ny, L*N, float*);
156    ALLOC(iy, L*N, int*);
157    ALLOC(iny, L*N, int*);
158    
159    ALLOC(xy, L, float);
160    ALLOC(yy, L, float);
161    ALLOC(yp, L, float);
162    ALLOC(_nbest, L, struct NBest);
163    ALLOC(nbest, L, struct NBest *);
164
165    for (j=0;j<N;j++)
166    {
167       x[j] = X[j]*NORM_SCALING_1;
168       p[j] = P[j]*NORM_SCALING_1;
169    }
170    
171    for (m=0;m<L;m++)
172       nbest[m] = &_nbest[m];
173    
174    for (m=0;m<L;m++)
175    {
176       ny[m] = &_ny[m*N];
177       iny[m] = &_iny[m*N];
178       y[m] = &_y[m*N];
179       iy[m] = &_iy[m*N];
180    }
181    
182    for (j=0;j<N;j++)
183    {
184       Rpp = MAC16_16(Rpp, P[j],P[j]);
185       Rxp = MAC16_16(Rxp, X[j],P[j]);
186    }
187    Rpp = ROUND(Rpp, NORM_SHIFT);
188    Rxp = ROUND(Rxp, NORM_SHIFT);
189    if (Rpp>NORM_SCALING)
190       celt_fatal("Rpp > 1");
191
192    /* We only need to initialise the zero because the first iteration only uses that */
193    for (i=0;i<N;i++)
194       y[0][i] = 0;
195    for (i=0;i<N;i++)
196       iy[0][i] = 0;
197    xy[0] = yy[0] = yp[0] = 0;
198
199    pulsesLeft = K;
200    while (pulsesLeft > 0)
201    {
202       int pulsesAtOnce=1;
203       int Lupdate = L;
204       int L2 = L;
205       
206       /* Decide on complexity strategy */
207       pulsesAtOnce = pulsesLeft/N;
208       if (pulsesAtOnce<1)
209          pulsesAtOnce = 1;
210       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
211          Lupdate = 1;
212       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
213       L2 = Lupdate;
214       if (L2>maxL)
215       {
216          L2 = maxL;
217          maxL *= N;
218       }
219
220       for (m=0;m<Lupdate;m++)
221          nbest[m]->score = -1e10f;
222
223       for (m=0;m<L2;m++)
224       {
225          for (j=0;j<N;j++)
226          {
227             int sign;
228             /*if (x[j]>0) sign=1; else sign=-1;*/
229             for (sign=-1;sign<=1;sign+=2)
230             {
231                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
232                float tmp_xy, tmp_yy, tmp_yp;
233                float score;
234                float g;
235                float s = SHL16(sign*pulsesAtOnce, yshift);
236                
237                /* All pulses at one location must have the same sign. */
238                if (iy[m][j]*sign < 0)
239                   continue;
240
241                /* Updating the sums of the new pulse(s) */
242                tmp_xy = xy[m] + s*X[j]               - _alpha*s*P[j]*Rxp*NORM_SCALING_1;
243                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*_alpha*_alpha*p[j]*p[j]*Rpp*NORM_SCALING_1 - 2.f*_alpha*s*p[j]*yp[m] - 2.f*s*s*_alpha*p[j]*p[j];
244                tmp_yp = yp[m] + s*p[j]               *(1.f-_alpha*Rpp*NORM_SCALING_1);
245                
246                /* Compute the gain such that ||p + g*y|| = 1 */
247                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp*NORM_SCALING_1) - tmp_yp)*approx_inv(tmp_yy);
248                /* Knowing that gain, what the error: (x-g*y)^2 
249                   (result is negated and we discard x^2 because it's constant) */
250                score = 2.f*g*tmp_xy*NORM_SCALING_1 - g*g*tmp_yy;
251
252                if (score>nbest[Lupdate-1]->score)
253                {
254                   int k;
255                   int id = Lupdate-1;
256                   struct NBest *tmp_best;
257
258                   /* Save some pointers that would be deleted and use them for the current entry*/
259                   tmp_best = nbest[Lupdate-1];
260                   while (id > 0 && score > nbest[id-1]->score)
261                      id--;
262                
263                   for (k=Lupdate-1;k>id;k--)
264                      nbest[k] = nbest[k-1];
265
266                   nbest[id] = tmp_best;
267                   nbest[id]->score = score;
268                   nbest[id]->pos = j;
269                   nbest[id]->orig = m;
270                   nbest[id]->sign = sign;
271                   nbest[id]->gain = g;
272                   nbest[id]->xy = tmp_xy;
273                   nbest[id]->yy = tmp_yy;
274                   nbest[id]->yp = tmp_yp;
275                }
276             }
277          }
278
279       }
280       
281       if (!(nbest[0]->score > -1e10f))
282          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
283       /* Only now that we've made the final choice, update ny/iny and others */
284       for (k=0;k<Lupdate;k++)
285       {
286          int n;
287          int is;
288          float s;
289          is = nbest[k]->sign*pulsesAtOnce;
290          s = SHL16(is, yshift);
291          for (n=0;n<N;n++)
292             ny[k][n] = y[nbest[k]->orig][n] - _alpha*s*p[nbest[k]->pos]*p[n];
293          ny[k][nbest[k]->pos] += s;
294
295          for (n=0;n<N;n++)
296             iny[k][n] = iy[nbest[k]->orig][n];
297          iny[k][nbest[k]->pos] += is;
298
299          xy[k] = nbest[k]->xy;
300          yy[k] = nbest[k]->yy;
301          yp[k] = nbest[k]->yp;
302       }
303       /* Swap ny/iny with y/iy */
304       for (k=0;k<Lupdate;k++)
305       {
306          float *tmp_ny;
307          int *tmp_iny;
308
309          tmp_ny = ny[k];
310          ny[k] = y[k];
311          y[k] = tmp_ny;
312          tmp_iny = iny[k];
313          iny[k] = iy[k];
314          iy[k] = tmp_iny;
315       }
316       pulsesLeft -= pulsesAtOnce;
317    }
318    
319 #if 0
320    if (0) {
321       float err=0;
322       for (i=0;i<N;i++)
323          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
324       /*if (N<=10)
325         printf ("%f %d %d\n", err, K, N);*/
326    }
327    /* Sanity checks, don't bother */
328    if (0) {
329       for (i=0;i<N;i++)
330          x[i] = p[i]+nbest[0]->gain*y[0][i];
331       float E=1e-15;
332       int ABS = 0;
333       for (i=0;i<N;i++)
334          ABS += abs(iy[0][i]);
335       /*if (K != ABS)
336          printf ("%d %d\n", K, ABS);*/
337       for (i=0;i<N;i++)
338          E += x[i]*x[i];
339       /*printf ("%f\n", E);*/
340       E = 1/sqrt(E);
341       for (i=0;i<N;i++)
342          x[i] *= E;
343    }
344 #endif
345    
346    encode_pulses(iy[0], N, K, enc);
347    
348    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
349       due to the recursive computation used in quantisation.
350       Not quite sure whether we need that or not */
351    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
352 }
353
354 /** Decode pulse vector and combine the result with the pitch vector to produce
355     the final normalised signal in the current band. */
356 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
357 {
358    VARDECL(int *iy);
359    ALLOC(iy, N, int);
360    decode_pulses(iy, N, K, dec);
361    mix_pitch_and_residual(iy, X, N, K, P, alpha);
362 }
363
364
365 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
366
367 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
368 {
369    int i,j;
370    int best=0;
371    float best_score=0;
372    float s = 1;
373    int sign;
374    float E;
375    float pred_gain;
376    int max_pos = N0-N/B;
377    if (max_pos > 32)
378       max_pos = 32;
379
380    for (i=0;i<max_pos*B;i+=B)
381    {
382       int j;
383       float xy=0, yy=0;
384       float score;
385       for (j=0;j<N;j++)
386       {
387          xy += 1.f*x[j]*Y[i+N-j-1];
388          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
389       }
390       score = xy*xy/(.001+yy);
391       if (score > best_score)
392       {
393          best_score = score;
394          best = i;
395          if (xy>0)
396             s = 1;
397          else
398             s = -1;
399       }
400    }
401    if (s<0)
402       sign = 1;
403    else
404       sign = 0;
405    /*printf ("%d %d ", sign, best);*/
406    ec_enc_uint(enc,sign,2);
407    ec_enc_uint(enc,best/B,max_pos);
408    /*printf ("%d %f\n", best, best_score);*/
409    
410    if (K>10)
411       pred_gain = pg[10];
412    else
413       pred_gain = pg[K];
414    E = 1e-10;
415    for (j=0;j<N;j++)
416    {
417       P[j] = s*Y[best+N-j-1];
418       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
419    }
420    E = pred_gain/sqrt(E);
421    for (j=0;j<N;j++)
422       P[j] *= E;
423    if (K>0)
424    {
425       for (j=0;j<N;j++)
426          x[j] -= P[j];
427    } else {
428       for (j=0;j<N;j++)
429          x[j] = P[j];
430    }
431    /*printf ("quant ");*/
432    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
433
434 }
435
436 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
437 {
438    int j;
439    int sign;
440    float s;
441    int best;
442    float E;
443    float pred_gain;
444    int max_pos = N0-N/B;
445    if (max_pos > 32)
446       max_pos = 32;
447    
448    sign = ec_dec_uint(dec, 2);
449    if (sign == 0)
450       s = 1;
451    else
452       s = -1;
453    
454    best = B*ec_dec_uint(dec, max_pos);
455    /*printf ("%d %d ", sign, best);*/
456
457    if (K>10)
458       pred_gain = pg[10];
459    else
460       pred_gain = pg[K];
461    E = 1e-10;
462    for (j=0;j<N;j++)
463    {
464       P[j] = s*Y[best+N-j-1];
465       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
466    }
467    E = pred_gain/sqrt(E);
468    for (j=0;j<N;j++)
469       P[j] *= E;
470    if (K==0)
471    {
472       for (j=0;j<N;j++)
473          x[j] = P[j];
474    }
475 }
476
477 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
478 {
479    int i, j;
480    float E;
481    
482    E = 1e-10;
483    if (N0 >= Nmax/2)
484    {
485       for (i=0;i<B;i++)
486       {
487          for (j=0;j<N/B;j++)
488          {
489             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
490             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
491          }
492       }
493    } else {
494       for (j=0;j<N;j++)
495       {
496          P[j] = Y[j];
497          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
498       }
499    }
500    E = 1.f/sqrt(E);
501    for (j=0;j<N;j++)
502       P[j] *= E;
503    for (j=0;j<N;j++)
504       x[j] = P[j];
505 }
506