fixed-point: intra_fold() converted
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72 #ifdef FIXED_POINT
73    int yshift;
74 #endif
75    SAVE_STACK;
76 #ifdef FIXED_POINT
77    yshift = 14-EC_ILOG(K);
78 #endif
79    ALLOC(y, N, celt_norm_t);
80
81    /*for (i=0;i<N;i++)
82    printf ("%d ", iy[i]);*/
83    Rpp = 0;
84    for (i=0;i<N;i++)
85       Rpp = MAC16_16(Rpp,P[i],P[i]);
86
87    Ryp = 0;
88    for (i=0;i<N;i++)
89       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
90
91    /* Remove part of the pitch component to compute the real residual from
92       the encoded (int) one */
93    for (i=0;i<N;i++)
94       y[i] = SUB16(SHL16(iy[i],yshift),
95                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
96
97    /* Recompute after the projection (I think it's right) */
98    Ryp = 0;
99    for (i=0;i<N;i++)
100       Ryp = MAC16_16(Ryp,y[i],P[i]);
101
102    Ryy = 0;
103    for (i=0;i<N;i++)
104       Ryy = MAC16_16(Ryy, y[i],y[i]);
105
106    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
107    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
108
109    for (i=0;i<N;i++)
110       X[i] = P[i] + MULT16_32_Q14(y[i], g);
111    RESTORE_STACK;
112 }
113
114 /** All the info necessary to keep track of a hypothesis during the search */
115 struct NBest {
116    celt_word32_t score;
117    int sign;
118    int pos;
119    int orig;
120    celt_word32_t xy;
121    celt_word32_t yy;
122    celt_word32_t yp;
123 };
124
125 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
126 {
127    int L = 3;
128    VARDECL(celt_norm_t *_y);
129    VARDECL(celt_norm_t *_ny);
130    VARDECL(int *_iy);
131    VARDECL(int *_iny);
132    VARDECL(celt_norm_t **y);
133    VARDECL(celt_norm_t **ny);
134    VARDECL(int **iy);
135    VARDECL(int **iny);
136    int i, j, k, m;
137    int pulsesLeft;
138    VARDECL(celt_word32_t *xy);
139    VARDECL(celt_word32_t *yy);
140    VARDECL(celt_word32_t *yp);
141    VARDECL(struct NBest *_nbest);
142    VARDECL(struct NBest **nbest);
143    celt_word32_t Rpp=0, Rxp=0;
144    int maxL = 1;
145 #ifdef FIXED_POINT
146    int yshift;
147 #endif
148    SAVE_STACK;
149
150 #ifdef FIXED_POINT
151    yshift = 14-EC_ILOG(K);
152 #endif
153
154    ALLOC(_y, L*N, celt_norm_t);
155    ALLOC(_ny, L*N, celt_norm_t);
156    ALLOC(_iy, L*N, int);
157    ALLOC(_iny, L*N, int);
158    ALLOC(y, L, celt_norm_t*);
159    ALLOC(ny, L, celt_norm_t*);
160    ALLOC(iy, L, int*);
161    ALLOC(iny, L, int*);
162    
163    ALLOC(xy, L, celt_word32_t);
164    ALLOC(yy, L, celt_word32_t);
165    ALLOC(yp, L, celt_word32_t);
166    ALLOC(_nbest, L, struct NBest);
167    ALLOC(nbest, L, struct NBest *);
168    
169    for (m=0;m<L;m++)
170       nbest[m] = &_nbest[m];
171    
172    for (m=0;m<L;m++)
173    {
174       ny[m] = &_ny[m*N];
175       iny[m] = &_iny[m*N];
176       y[m] = &_y[m*N];
177       iy[m] = &_iy[m*N];
178    }
179    
180    for (j=0;j<N;j++)
181    {
182       Rpp = MAC16_16(Rpp, P[j],P[j]);
183       Rxp = MAC16_16(Rxp, X[j],P[j]);
184    }
185    Rpp = ROUND(Rpp, NORM_SHIFT);
186    Rxp = ROUND(Rxp, NORM_SHIFT);
187    if (Rpp>NORM_SCALING)
188       celt_fatal("Rpp > 1");
189
190    /* We only need to initialise the zero because the first iteration only uses that */
191    for (i=0;i<N;i++)
192       y[0][i] = 0;
193    for (i=0;i<N;i++)
194       iy[0][i] = 0;
195    xy[0] = yy[0] = yp[0] = 0;
196
197    pulsesLeft = K;
198    while (pulsesLeft > 0)
199    {
200       int pulsesAtOnce=1;
201       int Lupdate = L;
202       int L2 = L;
203       
204       /* Decide on complexity strategy */
205       pulsesAtOnce = pulsesLeft/N;
206       if (pulsesAtOnce<1)
207          pulsesAtOnce = 1;
208       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
209          Lupdate = 1;
210       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
211       L2 = Lupdate;
212       if (L2>maxL)
213       {
214          L2 = maxL;
215          maxL *= N;
216       }
217
218       for (m=0;m<Lupdate;m++)
219          nbest[m]->score = -VERY_LARGE32;
220
221       for (m=0;m<L2;m++)
222       {
223          for (j=0;j<N;j++)
224          {
225             int sign;
226             /*if (x[j]>0) sign=1; else sign=-1;*/
227             for (sign=-1;sign<=1;sign+=2)
228             {
229                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
230                celt_word32_t Rxy, Ryy, Ryp;
231                celt_word16_t spj, aspj; /* Intermediate results */
232                celt_word32_t score;
233                celt_word32_t g;
234                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
235                
236                /* All pulses at one location must have the same sign. */
237                if (iy[m][j]*sign < 0)
238                   continue;
239
240                spj = MULT16_16_P14(s, P[j]);
241                aspj = MULT16_16_P15(alpha, spj);
242                /* Updating the sums of the new pulse(s) */
243                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_P15(alpha,spj),Rxp);
244                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
245                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
246                
247                /* Compute the gain such that ||p + g*y|| = 1 */
248                g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),Rpp)) - ROUND(Ryp,14),14),ROUND(Ryy,14));
249                
250                /* Knowing that gain, what the error: (x-g*y)^2 
251                   (result is negated and we discard x^2 because it's constant) */
252                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
253                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g) -
254                      MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
255
256                if (score>nbest[Lupdate-1]->score)
257                {
258                   int k;
259                   int id = Lupdate-1;
260                   struct NBest *tmp_best;
261
262                   /* Save some pointers that would be deleted and use them for the current entry*/
263                   tmp_best = nbest[Lupdate-1];
264                   while (id > 0 && score > nbest[id-1]->score)
265                      id--;
266                
267                   for (k=Lupdate-1;k>id;k--)
268                      nbest[k] = nbest[k-1];
269
270                   nbest[id] = tmp_best;
271                   nbest[id]->score = score;
272                   nbest[id]->pos = j;
273                   nbest[id]->orig = m;
274                   nbest[id]->sign = sign;
275                   nbest[id]->xy = Rxy;
276                   nbest[id]->yy = Ryy;
277                   nbest[id]->yp = Ryp;
278                }
279             }
280          }
281
282       }
283       
284       if (!(nbest[0]->score > -VERY_LARGE32))
285          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
286       /* Only now that we've made the final choice, update ny/iny and others */
287       for (k=0;k<Lupdate;k++)
288       {
289          int n;
290          int is;
291          celt_norm_t s;
292          is = nbest[k]->sign*pulsesAtOnce;
293          s = SHL16(is, yshift);
294          for (n=0;n<N;n++)
295             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
296          ny[k][nbest[k]->pos] += s;
297
298          for (n=0;n<N;n++)
299             iny[k][n] = iy[nbest[k]->orig][n];
300          iny[k][nbest[k]->pos] += is;
301
302          xy[k] = nbest[k]->xy;
303          yy[k] = nbest[k]->yy;
304          yp[k] = nbest[k]->yp;
305       }
306       /* Swap ny/iny with y/iy */
307       for (k=0;k<Lupdate;k++)
308       {
309          celt_norm_t *tmp_ny;
310          int *tmp_iny;
311
312          tmp_ny = ny[k];
313          ny[k] = y[k];
314          y[k] = tmp_ny;
315          tmp_iny = iny[k];
316          iny[k] = iy[k];
317          iy[k] = tmp_iny;
318       }
319       pulsesLeft -= pulsesAtOnce;
320    }
321    
322 #if 0
323    if (0) {
324       celt_word32_t err=0;
325       for (i=0;i<N;i++)
326          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
327       /*if (N<=10)
328         printf ("%f %d %d\n", err, K, N);*/
329    }
330    /* Sanity checks, don't bother */
331    if (0) {
332       for (i=0;i<N;i++)
333          x[i] = p[i]+nbest[0]->gain*y[0][i];
334       celt_word32_t E=1e-15;
335       int ABS = 0;
336       for (i=0;i<N;i++)
337          ABS += abs(iy[0][i]);
338       /*if (K != ABS)
339          printf ("%d %d\n", K, ABS);*/
340       for (i=0;i<N;i++)
341          E += x[i]*x[i];
342       /*printf ("%f\n", E);*/
343       E = 1/sqrt(E);
344       for (i=0;i<N;i++)
345          x[i] *= E;
346    }
347 #endif
348    
349    encode_pulses(iy[0], N, K, enc);
350    
351    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
352       due to the recursive computation used in quantisation.
353       Not quite sure whether we need that or not */
354    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
355    RESTORE_STACK;
356 }
357
358 /** Decode pulse vector and combine the result with the pitch vector to produce
359     the final normalised signal in the current band. */
360 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
361 {
362    VARDECL(int *iy);
363    SAVE_STACK;
364    ALLOC(iy, N, int);
365    decode_pulses(iy, N, K, dec);
366    mix_pitch_and_residual(iy, X, N, K, P, alpha);
367    RESTORE_STACK;
368 }
369
370
371 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
372
373 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
374 {
375    int i,j;
376    int best=0;
377    celt_word32_t best_score=0;
378    celt_word16_t s = 1;
379    int sign;
380    celt_word32_t E;
381    float pred_gain;
382    int max_pos = N0-N/B;
383    if (max_pos > 32)
384       max_pos = 32;
385
386    for (i=0;i<max_pos*B;i+=B)
387    {
388       int j;
389       celt_word32_t xy=0, yy=0;
390       float score;
391       for (j=0;j<N;j++)
392       {
393          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
394          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
395       }
396       score = 1.f*xy*xy/(.001+yy);
397       if (score > best_score)
398       {
399          best_score = score;
400          best = i;
401          if (xy>0)
402             s = 1;
403          else
404             s = -1;
405       }
406    }
407    if (s<0)
408       sign = 1;
409    else
410       sign = 0;
411    /*printf ("%d %d ", sign, best);*/
412    ec_enc_uint(enc,sign,2);
413    ec_enc_uint(enc,best/B,max_pos);
414    /*printf ("%d %f\n", best, best_score);*/
415    
416    if (K>10)
417       pred_gain = pg[10];
418    else
419       pred_gain = pg[K];
420    E = 1e-10;
421    for (j=0;j<N;j++)
422    {
423       P[j] = s*Y[best+N-j-1];
424       E = MAC16_16(E, P[j],P[j]);
425    }
426    pred_gain = NORM_SCALING*pred_gain/sqrt(E);
427    for (j=0;j<N;j++)
428       P[j] *= pred_gain;
429    if (K>0)
430    {
431       for (j=0;j<N;j++)
432          x[j] -= P[j];
433    } else {
434       for (j=0;j<N;j++)
435          x[j] = P[j];
436    }
437    /*printf ("quant ");*/
438    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
439
440 }
441
442 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
443 {
444    int j;
445    int sign;
446    celt_word16_t s;
447    int best;
448    celt_word32_t E;
449    float pred_gain;
450    int max_pos = N0-N/B;
451    if (max_pos > 32)
452       max_pos = 32;
453    
454    sign = ec_dec_uint(dec, 2);
455    if (sign == 0)
456       s = 1;
457    else
458       s = -1;
459    
460    best = B*ec_dec_uint(dec, max_pos);
461    /*printf ("%d %d ", sign, best);*/
462
463    if (K>10)
464       pred_gain = pg[10];
465    else
466       pred_gain = pg[K];
467    E = 1e-10;
468    for (j=0;j<N;j++)
469    {
470       P[j] = s*Y[best+N-j-1];
471       E = MAC16_16(E, P[j],P[j]);
472    }
473    pred_gain = NORM_SCALING*pred_gain/sqrt(E);
474    for (j=0;j<N;j++)
475       P[j] *= pred_gain;
476    if (K==0)
477    {
478       for (j=0;j<N;j++)
479          x[j] = P[j];
480    }
481 }
482
483 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
484 {
485    int i, j;
486    celt_word32_t E;
487    celt_word16_t g;
488    
489    E = EPSILON;
490    if (N0 >= Nmax/2)
491    {
492       for (i=0;i<B;i++)
493       {
494          for (j=0;j<N/B;j++)
495          {
496             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
497             E += P[j*B+i]*P[j*B+i];
498          }
499       }
500    } else {
501       for (j=0;j<N;j++)
502       {
503          P[j] = Y[j];
504          E = MAC16_16(E, P[j],P[j]);
505       }
506    }
507    g = DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E));
508    for (j=0;j<N;j++)
509       P[j] = PSHR32(MULT16_16(g, P[j]),8);
510    for (j=0;j<N;j++)
511       x[j] = P[j];
512 }
513