fixed-point: converted intra prediction and folding, unb0rked mixed-precision
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72    SAVE_STACK;
73 #ifdef FIXED_POINT
74    int yshift = 14-EC_ILOG(K);
75 #endif
76    ALLOC(y, N, celt_norm_t);
77
78    /*for (i=0;i<N;i++)
79    printf ("%d ", iy[i]);*/
80    Rpp = 0;
81    for (i=0;i<N;i++)
82       Rpp = MAC16_16(Rpp,P[i],P[i]);
83
84    Ryp = 0;
85    for (i=0;i<N;i++)
86       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
87
88    /* Remove part of the pitch component to compute the real residual from
89       the encoded (int) one */
90    for (i=0;i<N;i++)
91       y[i] = SUB16(SHL16(iy[i],yshift),
92                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
93
94    /* Recompute after the projection (I think it's right) */
95    Ryp = 0;
96    for (i=0;i<N;i++)
97       Ryp = MAC16_16(Ryp,y[i],P[i]);
98
99    Ryy = 0;
100    for (i=0;i<N;i++)
101       Ryy = MAC16_16(Ryy, y[i],y[i]);
102
103    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
104    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
105
106    for (i=0;i<N;i++)
107       X[i] = P[i] + MULT16_32_Q14(y[i], g);
108    RESTORE_STACK;
109 }
110
111 /** All the info necessary to keep track of a hypothesis during the search */
112 struct NBest {
113    celt_word32_t score;
114    int sign;
115    int pos;
116    int orig;
117    celt_word32_t xy;
118    celt_word32_t yy;
119    celt_word32_t yp;
120 };
121
122 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
123 {
124    int L = 3;
125    VARDECL(celt_norm_t *_y);
126    VARDECL(celt_norm_t *_ny);
127    VARDECL(int *_iy);
128    VARDECL(int *_iny);
129    VARDECL(celt_norm_t **y);
130    VARDECL(celt_norm_t **ny);
131    VARDECL(int **iy);
132    VARDECL(int **iny);
133    SAVE_STACK;
134    int i, j, k, m;
135    int pulsesLeft;
136    VARDECL(celt_word32_t *xy);
137    VARDECL(celt_word32_t *yy);
138    VARDECL(celt_word32_t *yp);
139    VARDECL(struct NBest *_nbest);
140    VARDECL(struct NBest **nbest);
141    celt_word32_t Rpp=0, Rxp=0;
142    int maxL = 1;
143 #ifdef FIXED_POINT
144    int yshift = 14-EC_ILOG(K);
145 #endif
146
147    ALLOC(_y, L*N, celt_norm_t);
148    ALLOC(_ny, L*N, celt_norm_t);
149    ALLOC(_iy, L*N, int);
150    ALLOC(_iny, L*N, int);
151    ALLOC(y, L*N, celt_norm_t*);
152    ALLOC(ny, L*N, celt_norm_t*);
153    ALLOC(iy, L*N, int*);
154    ALLOC(iny, L*N, int*);
155    
156    ALLOC(xy, L, celt_word32_t);
157    ALLOC(yy, L, celt_word32_t);
158    ALLOC(yp, L, celt_word32_t);
159    ALLOC(_nbest, L, struct NBest);
160    ALLOC(nbest, L, struct NBest *);
161    
162    for (m=0;m<L;m++)
163       nbest[m] = &_nbest[m];
164    
165    for (m=0;m<L;m++)
166    {
167       ny[m] = &_ny[m*N];
168       iny[m] = &_iny[m*N];
169       y[m] = &_y[m*N];
170       iy[m] = &_iy[m*N];
171    }
172    
173    for (j=0;j<N;j++)
174    {
175       Rpp = MAC16_16(Rpp, P[j],P[j]);
176       Rxp = MAC16_16(Rxp, X[j],P[j]);
177    }
178    Rpp = ROUND(Rpp, NORM_SHIFT);
179    Rxp = ROUND(Rxp, NORM_SHIFT);
180    if (Rpp>NORM_SCALING)
181       celt_fatal("Rpp > 1");
182
183    /* We only need to initialise the zero because the first iteration only uses that */
184    for (i=0;i<N;i++)
185       y[0][i] = 0;
186    for (i=0;i<N;i++)
187       iy[0][i] = 0;
188    xy[0] = yy[0] = yp[0] = 0;
189
190    pulsesLeft = K;
191    while (pulsesLeft > 0)
192    {
193       int pulsesAtOnce=1;
194       int Lupdate = L;
195       int L2 = L;
196       
197       /* Decide on complexity strategy */
198       pulsesAtOnce = pulsesLeft/N;
199       if (pulsesAtOnce<1)
200          pulsesAtOnce = 1;
201       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
202          Lupdate = 1;
203       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
204       L2 = Lupdate;
205       if (L2>maxL)
206       {
207          L2 = maxL;
208          maxL *= N;
209       }
210
211       for (m=0;m<Lupdate;m++)
212          nbest[m]->score = -VERY_LARGE32;
213
214       for (m=0;m<L2;m++)
215       {
216          for (j=0;j<N;j++)
217          {
218             int sign;
219             /*if (x[j]>0) sign=1; else sign=-1;*/
220             for (sign=-1;sign<=1;sign+=2)
221             {
222                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
223                celt_word32_t Rxy, Ryy, Ryp;
224                celt_word16_t spj, aspj; /* Intermediate results */
225                celt_word32_t score;
226                celt_word32_t g;
227                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
228                
229                /* All pulses at one location must have the same sign. */
230                if (iy[m][j]*sign < 0)
231                   continue;
232
233                spj = MULT16_16_P14(s, P[j]);
234                aspj = MULT16_16_P15(alpha, spj);
235                /* Updating the sums of the new pulse(s) */
236                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_P15(alpha,spj),Rxp);
237                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
238                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
239                
240                /* Compute the gain such that ||p + g*y|| = 1 */
241                g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),Rpp)) - ROUND(Ryp,14),14),ROUND(Ryy,14));
242                //g *= NORM_SCALING_1;
243                /* Knowing that gain, what the error: (x-g*y)^2 
244                   (result is negated and we discard x^2 because it's constant) */
245                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
246                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g) -
247                      MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
248
249                if (score>nbest[Lupdate-1]->score)
250                {
251                   int k;
252                   int id = Lupdate-1;
253                   struct NBest *tmp_best;
254
255                   /* Save some pointers that would be deleted and use them for the current entry*/
256                   tmp_best = nbest[Lupdate-1];
257                   while (id > 0 && score > nbest[id-1]->score)
258                      id--;
259                
260                   for (k=Lupdate-1;k>id;k--)
261                      nbest[k] = nbest[k-1];
262
263                   nbest[id] = tmp_best;
264                   nbest[id]->score = score;
265                   nbest[id]->pos = j;
266                   nbest[id]->orig = m;
267                   nbest[id]->sign = sign;
268                   nbest[id]->xy = Rxy;
269                   nbest[id]->yy = Ryy;
270                   nbest[id]->yp = Ryp;
271                }
272             }
273          }
274
275       }
276       
277       if (!(nbest[0]->score > -VERY_LARGE32))
278          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
279       /* Only now that we've made the final choice, update ny/iny and others */
280       for (k=0;k<Lupdate;k++)
281       {
282          int n;
283          int is;
284          celt_norm_t s;
285          is = nbest[k]->sign*pulsesAtOnce;
286          s = SHL16(is, yshift);
287          for (n=0;n<N;n++)
288             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
289          ny[k][nbest[k]->pos] += s;
290
291          for (n=0;n<N;n++)
292             iny[k][n] = iy[nbest[k]->orig][n];
293          iny[k][nbest[k]->pos] += is;
294
295          xy[k] = nbest[k]->xy;
296          yy[k] = nbest[k]->yy;
297          yp[k] = nbest[k]->yp;
298       }
299       /* Swap ny/iny with y/iy */
300       for (k=0;k<Lupdate;k++)
301       {
302          celt_norm_t *tmp_ny;
303          int *tmp_iny;
304
305          tmp_ny = ny[k];
306          ny[k] = y[k];
307          y[k] = tmp_ny;
308          tmp_iny = iny[k];
309          iny[k] = iy[k];
310          iy[k] = tmp_iny;
311       }
312       pulsesLeft -= pulsesAtOnce;
313    }
314    
315 #if 0
316    if (0) {
317       celt_word32_t err=0;
318       for (i=0;i<N;i++)
319          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
320       /*if (N<=10)
321         printf ("%f %d %d\n", err, K, N);*/
322    }
323    /* Sanity checks, don't bother */
324    if (0) {
325       for (i=0;i<N;i++)
326          x[i] = p[i]+nbest[0]->gain*y[0][i];
327       celt_word32_t E=1e-15;
328       int ABS = 0;
329       for (i=0;i<N;i++)
330          ABS += abs(iy[0][i]);
331       /*if (K != ABS)
332          printf ("%d %d\n", K, ABS);*/
333       for (i=0;i<N;i++)
334          E += x[i]*x[i];
335       /*printf ("%f\n", E);*/
336       E = 1/sqrt(E);
337       for (i=0;i<N;i++)
338          x[i] *= E;
339    }
340 #endif
341    
342    encode_pulses(iy[0], N, K, enc);
343    
344    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
345       due to the recursive computation used in quantisation.
346       Not quite sure whether we need that or not */
347    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
348    RESTORE_STACK;
349 }
350
351 /** Decode pulse vector and combine the result with the pitch vector to produce
352     the final normalised signal in the current band. */
353 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
354 {
355    VARDECL(int *iy);
356    SAVE_STACK;
357    ALLOC(iy, N, int);
358    decode_pulses(iy, N, K, dec);
359    mix_pitch_and_residual(iy, X, N, K, P, alpha);
360    RESTORE_STACK;
361 }
362
363
364 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
365
366 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
367 {
368    int i,j;
369    int best=0;
370    celt_word32_t best_score=0;
371    celt_word16_t s = 1;
372    int sign;
373    celt_word32_t E;
374    float pred_gain;
375    int max_pos = N0-N/B;
376    if (max_pos > 32)
377       max_pos = 32;
378
379    for (i=0;i<max_pos*B;i+=B)
380    {
381       int j;
382       celt_word32_t xy=0, yy=0;
383       float score;
384       for (j=0;j<N;j++)
385       {
386          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
387          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
388       }
389       score = 1.f*xy*xy/(.001+yy);
390       if (score > best_score)
391       {
392          best_score = score;
393          best = i;
394          if (xy>0)
395             s = 1;
396          else
397             s = -1;
398       }
399    }
400    if (s<0)
401       sign = 1;
402    else
403       sign = 0;
404    /*printf ("%d %d ", sign, best);*/
405    ec_enc_uint(enc,sign,2);
406    ec_enc_uint(enc,best/B,max_pos);
407    /*printf ("%d %f\n", best, best_score);*/
408    
409    if (K>10)
410       pred_gain = pg[10];
411    else
412       pred_gain = pg[K];
413    E = 1e-10;
414    for (j=0;j<N;j++)
415    {
416       P[j] = s*Y[best+N-j-1];
417       E = MAC16_16(E, P[j],P[j]);
418    }
419    pred_gain = NORM_SCALING*pred_gain/sqrt(E);
420    for (j=0;j<N;j++)
421       P[j] *= pred_gain;
422    if (K>0)
423    {
424       for (j=0;j<N;j++)
425          x[j] -= P[j];
426    } else {
427       for (j=0;j<N;j++)
428          x[j] = P[j];
429    }
430    /*printf ("quant ");*/
431    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
432
433 }
434
435 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
436 {
437    int j;
438    int sign;
439    celt_word16_t s;
440    int best;
441    celt_word32_t E;
442    float pred_gain;
443    int max_pos = N0-N/B;
444    if (max_pos > 32)
445       max_pos = 32;
446    
447    sign = ec_dec_uint(dec, 2);
448    if (sign == 0)
449       s = 1;
450    else
451       s = -1;
452    
453    best = B*ec_dec_uint(dec, max_pos);
454    /*printf ("%d %d ", sign, best);*/
455
456    if (K>10)
457       pred_gain = pg[10];
458    else
459       pred_gain = pg[K];
460    E = 1e-10;
461    for (j=0;j<N;j++)
462    {
463       P[j] = s*Y[best+N-j-1];
464       E = MAC16_16(E, P[j],P[j]);
465    }
466    pred_gain = NORM_SCALING*pred_gain/sqrt(E);
467    for (j=0;j<N;j++)
468       P[j] *= pred_gain;
469    if (K==0)
470    {
471       for (j=0;j<N;j++)
472          x[j] = P[j];
473    }
474 }
475
476 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
477 {
478    int i, j;
479    celt_word32_t E;
480    float g;
481    
482    E = 1e-10;
483    if (N0 >= Nmax/2)
484    {
485       for (i=0;i<B;i++)
486       {
487          for (j=0;j<N/B;j++)
488          {
489             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
490             E += P[j*B+i]*P[j*B+i];
491          }
492       }
493    } else {
494       for (j=0;j<N;j++)
495       {
496          P[j] = Y[j];
497          E = MAC16_16(E, P[j],P[j]);
498       }
499    }
500    g = NORM_SCALING/sqrt(E);
501    for (j=0;j<N;j++)
502       P[j] *= g;
503    for (j=0;j<N;j++)
504       x[j] = P[j];
505 }
506