fixed-point: finished intra_prediction(). No float ops left in vq.c
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72 #ifdef FIXED_POINT
73    int yshift;
74 #endif
75    SAVE_STACK;
76 #ifdef FIXED_POINT
77    yshift = 14-EC_ILOG(K);
78 #endif
79    ALLOC(y, N, celt_norm_t);
80
81    /*for (i=0;i<N;i++)
82    printf ("%d ", iy[i]);*/
83    Rpp = 0;
84    for (i=0;i<N;i++)
85       Rpp = MAC16_16(Rpp,P[i],P[i]);
86
87    Ryp = 0;
88    for (i=0;i<N;i++)
89       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
90
91    /* Remove part of the pitch component to compute the real residual from
92       the encoded (int) one */
93    for (i=0;i<N;i++)
94       y[i] = SUB16(SHL16(iy[i],yshift),
95                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
96
97    /* Recompute after the projection (I think it's right) */
98    Ryp = 0;
99    for (i=0;i<N;i++)
100       Ryp = MAC16_16(Ryp,y[i],P[i]);
101
102    Ryy = 0;
103    for (i=0;i<N;i++)
104       Ryy = MAC16_16(Ryy, y[i],y[i]);
105
106    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
107    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
108
109    for (i=0;i<N;i++)
110       X[i] = P[i] + MULT16_32_Q14(y[i], g);
111    RESTORE_STACK;
112 }
113
114 /** All the info necessary to keep track of a hypothesis during the search */
115 struct NBest {
116    celt_word32_t score;
117    int sign;
118    int pos;
119    int orig;
120    celt_word32_t xy;
121    celt_word32_t yy;
122    celt_word32_t yp;
123 };
124
125 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
126 {
127    int L = 3;
128    VARDECL(celt_norm_t *_y);
129    VARDECL(celt_norm_t *_ny);
130    VARDECL(int *_iy);
131    VARDECL(int *_iny);
132    VARDECL(celt_norm_t **y);
133    VARDECL(celt_norm_t **ny);
134    VARDECL(int **iy);
135    VARDECL(int **iny);
136    int i, j, k, m;
137    int pulsesLeft;
138    VARDECL(celt_word32_t *xy);
139    VARDECL(celt_word32_t *yy);
140    VARDECL(celt_word32_t *yp);
141    VARDECL(struct NBest *_nbest);
142    VARDECL(struct NBest **nbest);
143    celt_word32_t Rpp=0, Rxp=0;
144    int maxL = 1;
145 #ifdef FIXED_POINT
146    int yshift;
147 #endif
148    SAVE_STACK;
149
150 #ifdef FIXED_POINT
151    yshift = 14-EC_ILOG(K);
152 #endif
153
154    ALLOC(_y, L*N, celt_norm_t);
155    ALLOC(_ny, L*N, celt_norm_t);
156    ALLOC(_iy, L*N, int);
157    ALLOC(_iny, L*N, int);
158    ALLOC(y, L, celt_norm_t*);
159    ALLOC(ny, L, celt_norm_t*);
160    ALLOC(iy, L, int*);
161    ALLOC(iny, L, int*);
162    
163    ALLOC(xy, L, celt_word32_t);
164    ALLOC(yy, L, celt_word32_t);
165    ALLOC(yp, L, celt_word32_t);
166    ALLOC(_nbest, L, struct NBest);
167    ALLOC(nbest, L, struct NBest *);
168    
169    for (m=0;m<L;m++)
170       nbest[m] = &_nbest[m];
171    
172    for (m=0;m<L;m++)
173    {
174       ny[m] = &_ny[m*N];
175       iny[m] = &_iny[m*N];
176       y[m] = &_y[m*N];
177       iy[m] = &_iy[m*N];
178    }
179    
180    for (j=0;j<N;j++)
181    {
182       Rpp = MAC16_16(Rpp, P[j],P[j]);
183       Rxp = MAC16_16(Rxp, X[j],P[j]);
184    }
185    Rpp = ROUND(Rpp, NORM_SHIFT);
186    Rxp = ROUND(Rxp, NORM_SHIFT);
187    if (Rpp>NORM_SCALING)
188       celt_fatal("Rpp > 1");
189
190    /* We only need to initialise the zero because the first iteration only uses that */
191    for (i=0;i<N;i++)
192       y[0][i] = 0;
193    for (i=0;i<N;i++)
194       iy[0][i] = 0;
195    xy[0] = yy[0] = yp[0] = 0;
196
197    pulsesLeft = K;
198    while (pulsesLeft > 0)
199    {
200       int pulsesAtOnce=1;
201       int Lupdate = L;
202       int L2 = L;
203       
204       /* Decide on complexity strategy */
205       pulsesAtOnce = pulsesLeft/N;
206       if (pulsesAtOnce<1)
207          pulsesAtOnce = 1;
208       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
209          Lupdate = 1;
210       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
211       L2 = Lupdate;
212       if (L2>maxL)
213       {
214          L2 = maxL;
215          maxL *= N;
216       }
217
218       for (m=0;m<Lupdate;m++)
219          nbest[m]->score = -VERY_LARGE32;
220
221       for (m=0;m<L2;m++)
222       {
223          for (j=0;j<N;j++)
224          {
225             int sign;
226             /*if (x[j]>0) sign=1; else sign=-1;*/
227             for (sign=-1;sign<=1;sign+=2)
228             {
229                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
230                celt_word32_t Rxy, Ryy, Ryp;
231                celt_word16_t spj, aspj; /* Intermediate results */
232                celt_word32_t score;
233                celt_word32_t g;
234                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
235                
236                /* All pulses at one location must have the same sign. */
237                if (iy[m][j]*sign < 0)
238                   continue;
239
240                spj = MULT16_16_P14(s, P[j]);
241                aspj = MULT16_16_P15(alpha, spj);
242                /* Updating the sums of the new pulse(s) */
243                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_P15(alpha,spj),Rxp);
244                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
245                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
246                
247                /* Compute the gain such that ||p + g*y|| = 1 */
248                g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),Rpp)) - ROUND(Ryp,14),14),ROUND(Ryy,14));
249                
250                /* Knowing that gain, what the error: (x-g*y)^2 
251                   (result is negated and we discard x^2 because it's constant) */
252                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
253                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g) -
254                      MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
255
256                if (score>nbest[Lupdate-1]->score)
257                {
258                   int k;
259                   int id = Lupdate-1;
260                   struct NBest *tmp_best;
261
262                   /* Save some pointers that would be deleted and use them for the current entry*/
263                   tmp_best = nbest[Lupdate-1];
264                   while (id > 0 && score > nbest[id-1]->score)
265                      id--;
266                
267                   for (k=Lupdate-1;k>id;k--)
268                      nbest[k] = nbest[k-1];
269
270                   nbest[id] = tmp_best;
271                   nbest[id]->score = score;
272                   nbest[id]->pos = j;
273                   nbest[id]->orig = m;
274                   nbest[id]->sign = sign;
275                   nbest[id]->xy = Rxy;
276                   nbest[id]->yy = Ryy;
277                   nbest[id]->yp = Ryp;
278                }
279             }
280          }
281
282       }
283       
284       if (!(nbest[0]->score > -VERY_LARGE32))
285          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
286       /* Only now that we've made the final choice, update ny/iny and others */
287       for (k=0;k<Lupdate;k++)
288       {
289          int n;
290          int is;
291          celt_norm_t s;
292          is = nbest[k]->sign*pulsesAtOnce;
293          s = SHL16(is, yshift);
294          for (n=0;n<N;n++)
295             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
296          ny[k][nbest[k]->pos] += s;
297
298          for (n=0;n<N;n++)
299             iny[k][n] = iy[nbest[k]->orig][n];
300          iny[k][nbest[k]->pos] += is;
301
302          xy[k] = nbest[k]->xy;
303          yy[k] = nbest[k]->yy;
304          yp[k] = nbest[k]->yp;
305       }
306       /* Swap ny/iny with y/iy */
307       for (k=0;k<Lupdate;k++)
308       {
309          celt_norm_t *tmp_ny;
310          int *tmp_iny;
311
312          tmp_ny = ny[k];
313          ny[k] = y[k];
314          y[k] = tmp_ny;
315          tmp_iny = iny[k];
316          iny[k] = iy[k];
317          iy[k] = tmp_iny;
318       }
319       pulsesLeft -= pulsesAtOnce;
320    }
321    
322 #if 0
323    if (0) {
324       celt_word32_t err=0;
325       for (i=0;i<N;i++)
326          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
327       /*if (N<=10)
328         printf ("%f %d %d\n", err, K, N);*/
329    }
330    /* Sanity checks, don't bother */
331    if (0) {
332       for (i=0;i<N;i++)
333          x[i] = p[i]+nbest[0]->gain*y[0][i];
334       celt_word32_t E=1e-15;
335       int ABS = 0;
336       for (i=0;i<N;i++)
337          ABS += abs(iy[0][i]);
338       /*if (K != ABS)
339          printf ("%d %d\n", K, ABS);*/
340       for (i=0;i<N;i++)
341          E += x[i]*x[i];
342       /*printf ("%f\n", E);*/
343       E = 1/sqrt(E);
344       for (i=0;i<N;i++)
345          x[i] *= E;
346    }
347 #endif
348    
349    encode_pulses(iy[0], N, K, enc);
350    
351    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
352       due to the recursive computation used in quantisation.
353       Not quite sure whether we need that or not */
354    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
355    RESTORE_STACK;
356 }
357
358 /** Decode pulse vector and combine the result with the pitch vector to produce
359     the final normalised signal in the current band. */
360 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
361 {
362    VARDECL(int *iy);
363    SAVE_STACK;
364    ALLOC(iy, N, int);
365    decode_pulses(iy, N, K, dec);
366    mix_pitch_and_residual(iy, X, N, K, P, alpha);
367    RESTORE_STACK;
368 }
369
370 #ifdef FIXED_POINT
371 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
372 #else
373 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
374 #endif
375
376 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
377 {
378    int i,j;
379    int best=0;
380    celt_word32_t best_score=0;
381    celt_word16_t s = 1;
382    int sign;
383    celt_word32_t E;
384    celt_word16_t pred_gain;
385    int max_pos = N0-N/B;
386    if (max_pos > 32)
387       max_pos = 32;
388
389    for (i=0;i<max_pos*B;i+=B)
390    {
391       int j;
392       celt_word32_t xy=0, yy=0;
393       celt_word32_t score;
394       for (j=0;j<N;j++)
395       {
396          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
397          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
398       }
399       score = DIV32(MULT16_16(ROUND(xy,14),ROUND(xy,14)), ROUND(yy,14));
400       if (score > best_score)
401       {
402          best_score = score;
403          best = i;
404          if (xy>0)
405             s = 1;
406          else
407             s = -1;
408       }
409    }
410    if (s<0)
411       sign = 1;
412    else
413       sign = 0;
414    /*printf ("%d %d ", sign, best);*/
415    ec_enc_uint(enc,sign,2);
416    ec_enc_uint(enc,best/B,max_pos);
417    /*printf ("%d %f\n", best, best_score);*/
418    
419    if (K>10)
420       pred_gain = pg[10];
421    else
422       pred_gain = pg[K];
423    E = EPSILON;
424    for (j=0;j<N;j++)
425    {
426       P[j] = s*Y[best+N-j-1];
427       E = MAC16_16(E, P[j],P[j]);
428    }
429    /*pred_gain = pred_gain/sqrt(E);*/
430    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
431    for (j=0;j<N;j++)
432       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
433    if (K>0)
434    {
435       for (j=0;j<N;j++)
436          x[j] -= P[j];
437    } else {
438       for (j=0;j<N;j++)
439          x[j] = P[j];
440    }
441    /*printf ("quant ");*/
442    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
443
444 }
445
446 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
447 {
448    int j;
449    int sign;
450    celt_word16_t s;
451    int best;
452    celt_word32_t E;
453    celt_word16_t pred_gain;
454    int max_pos = N0-N/B;
455    if (max_pos > 32)
456       max_pos = 32;
457    
458    sign = ec_dec_uint(dec, 2);
459    if (sign == 0)
460       s = 1;
461    else
462       s = -1;
463    
464    best = B*ec_dec_uint(dec, max_pos);
465    /*printf ("%d %d ", sign, best);*/
466
467    if (K>10)
468       pred_gain = pg[10];
469    else
470       pred_gain = pg[K];
471    E = EPSILON;
472    for (j=0;j<N;j++)
473    {
474       P[j] = s*Y[best+N-j-1];
475       E = MAC16_16(E, P[j],P[j]);
476    }
477    /*pred_gain = pred_gain/sqrt(E);*/
478    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
479    for (j=0;j<N;j++)
480       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
481    if (K==0)
482    {
483       for (j=0;j<N;j++)
484          x[j] = P[j];
485    }
486 }
487
488 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
489 {
490    int i, j;
491    celt_word32_t E;
492    celt_word16_t g;
493    
494    E = EPSILON;
495    if (N0 >= Nmax/2)
496    {
497       for (i=0;i<B;i++)
498       {
499          for (j=0;j<N/B;j++)
500          {
501             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
502             E += P[j*B+i]*P[j*B+i];
503          }
504       }
505    } else {
506       for (j=0;j<N;j++)
507       {
508          P[j] = Y[j];
509          E = MAC16_16(E, P[j],P[j]);
510       }
511    }
512    g = DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E));
513    for (j=0;j<N;j++)
514       P[j] = PSHR32(MULT16_16(g, P[j]),8);
515    for (j=0;j<N;j++)
516       x[j] = P[j];
517 }
518