fixed-point: unquant_energy_mono() has received the fixed-point code from
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72 #ifdef FIXED_POINT
73    int yshift = 14-EC_ILOG(K);
74 #endif
75    ALLOC(y, N, celt_norm_t);
76
77    /*for (i=0;i<N;i++)
78    printf ("%d ", iy[i]);*/
79    Rpp = 0;
80    for (i=0;i<N;i++)
81       Rpp = MAC16_16(Rpp,P[i],P[i]);
82
83    Ryp = 0;
84    for (i=0;i<N;i++)
85       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
86
87    /* Remove part of the pitch component to compute the real residual from
88       the encoded (int) one */
89    for (i=0;i<N;i++)
90       y[i] = SUB16(SHL16(iy[i],yshift),
91                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
92
93    /* Recompute after the projection (I think it's right) */
94    Ryp = 0;
95    for (i=0;i<N;i++)
96       Ryp = MAC16_16(Ryp,y[i],P[i]);
97
98    Ryy = 0;
99    for (i=0;i<N;i++)
100       Ryy = MAC16_16(Ryy, y[i],y[i]);
101
102    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
103    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
104
105    for (i=0;i<N;i++)
106       X[i] = P[i] + MULT16_32_Q14(y[i], g);
107 }
108
109 /** All the info necessary to keep track of a hypothesis during the search */
110 struct NBest {
111    celt_word32_t score;
112    int sign;
113    int pos;
114    int orig;
115    celt_word32_t xy;
116    celt_word32_t yy;
117    celt_word32_t yp;
118 };
119
120 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
121 {
122    int L = 3;
123    VARDECL(celt_norm_t *_y);
124    VARDECL(celt_norm_t *_ny);
125    VARDECL(int *_iy);
126    VARDECL(int *_iny);
127    VARDECL(celt_norm_t **y);
128    VARDECL(celt_norm_t **ny);
129    VARDECL(int **iy);
130    VARDECL(int **iny);
131    int i, j, k, m;
132    int pulsesLeft;
133    VARDECL(celt_word32_t *xy);
134    VARDECL(celt_word32_t *yy);
135    VARDECL(celt_word32_t *yp);
136    VARDECL(struct NBest *_nbest);
137    VARDECL(struct NBest **nbest);
138    celt_word32_t Rpp=0, Rxp=0;
139    int maxL = 1;
140 #ifdef FIXED_POINT
141    int yshift = 14-EC_ILOG(K);
142 #endif
143
144    ALLOC(_y, L*N, celt_norm_t);
145    ALLOC(_ny, L*N, celt_norm_t);
146    ALLOC(_iy, L*N, int);
147    ALLOC(_iny, L*N, int);
148    ALLOC(y, L*N, celt_norm_t*);
149    ALLOC(ny, L*N, celt_norm_t*);
150    ALLOC(iy, L*N, int*);
151    ALLOC(iny, L*N, int*);
152    
153    ALLOC(xy, L, celt_word32_t);
154    ALLOC(yy, L, celt_word32_t);
155    ALLOC(yp, L, celt_word32_t);
156    ALLOC(_nbest, L, struct NBest);
157    ALLOC(nbest, L, struct NBest *);
158    
159    for (m=0;m<L;m++)
160       nbest[m] = &_nbest[m];
161    
162    for (m=0;m<L;m++)
163    {
164       ny[m] = &_ny[m*N];
165       iny[m] = &_iny[m*N];
166       y[m] = &_y[m*N];
167       iy[m] = &_iy[m*N];
168    }
169    
170    for (j=0;j<N;j++)
171    {
172       Rpp = MAC16_16(Rpp, P[j],P[j]);
173       Rxp = MAC16_16(Rxp, X[j],P[j]);
174    }
175    Rpp = ROUND(Rpp, NORM_SHIFT);
176    Rxp = ROUND(Rxp, NORM_SHIFT);
177    if (Rpp>NORM_SCALING)
178       celt_fatal("Rpp > 1");
179
180    /* We only need to initialise the zero because the first iteration only uses that */
181    for (i=0;i<N;i++)
182       y[0][i] = 0;
183    for (i=0;i<N;i++)
184       iy[0][i] = 0;
185    xy[0] = yy[0] = yp[0] = 0;
186
187    pulsesLeft = K;
188    while (pulsesLeft > 0)
189    {
190       int pulsesAtOnce=1;
191       int Lupdate = L;
192       int L2 = L;
193       
194       /* Decide on complexity strategy */
195       pulsesAtOnce = pulsesLeft/N;
196       if (pulsesAtOnce<1)
197          pulsesAtOnce = 1;
198       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
199          Lupdate = 1;
200       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
201       L2 = Lupdate;
202       if (L2>maxL)
203       {
204          L2 = maxL;
205          maxL *= N;
206       }
207
208       for (m=0;m<Lupdate;m++)
209          nbest[m]->score = -VERY_LARGE32;
210
211       for (m=0;m<L2;m++)
212       {
213          for (j=0;j<N;j++)
214          {
215             int sign;
216             /*if (x[j]>0) sign=1; else sign=-1;*/
217             for (sign=-1;sign<=1;sign+=2)
218             {
219                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
220                celt_word32_t Rxy, Ryy, Ryp;
221                celt_word16_t spj, aspj; /* Intermediate results */
222                celt_word32_t score;
223                celt_word32_t g;
224                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
225                
226                /* All pulses at one location must have the same sign. */
227                if (iy[m][j]*sign < 0)
228                   continue;
229
230                spj = MULT16_16_P14(s, P[j]);
231                aspj = MULT16_16_P15(alpha, spj);
232                /* Updating the sums of the new pulse(s) */
233                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_P15(alpha,spj),Rxp);
234                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
235                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
236                
237                /* Compute the gain such that ||p + g*y|| = 1 */
238                g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),Rpp)) - ROUND(Ryp,14),14),ROUND(Ryy,14));
239                //g *= NORM_SCALING_1;
240                /* Knowing that gain, what the error: (x-g*y)^2 
241                   (result is negated and we discard x^2 because it's constant) */
242                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
243                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g) -
244                      MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
245
246                if (score>nbest[Lupdate-1]->score)
247                {
248                   int k;
249                   int id = Lupdate-1;
250                   struct NBest *tmp_best;
251
252                   /* Save some pointers that would be deleted and use them for the current entry*/
253                   tmp_best = nbest[Lupdate-1];
254                   while (id > 0 && score > nbest[id-1]->score)
255                      id--;
256                
257                   for (k=Lupdate-1;k>id;k--)
258                      nbest[k] = nbest[k-1];
259
260                   nbest[id] = tmp_best;
261                   nbest[id]->score = score;
262                   nbest[id]->pos = j;
263                   nbest[id]->orig = m;
264                   nbest[id]->sign = sign;
265                   nbest[id]->xy = Rxy;
266                   nbest[id]->yy = Ryy;
267                   nbest[id]->yp = Ryp;
268                }
269             }
270          }
271
272       }
273       
274       if (!(nbest[0]->score > -VERY_LARGE32))
275          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
276       /* Only now that we've made the final choice, update ny/iny and others */
277       for (k=0;k<Lupdate;k++)
278       {
279          int n;
280          int is;
281          celt_norm_t s;
282          is = nbest[k]->sign*pulsesAtOnce;
283          s = SHL16(is, yshift);
284          for (n=0;n<N;n++)
285             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
286          ny[k][nbest[k]->pos] += s;
287
288          for (n=0;n<N;n++)
289             iny[k][n] = iy[nbest[k]->orig][n];
290          iny[k][nbest[k]->pos] += is;
291
292          xy[k] = nbest[k]->xy;
293          yy[k] = nbest[k]->yy;
294          yp[k] = nbest[k]->yp;
295       }
296       /* Swap ny/iny with y/iy */
297       for (k=0;k<Lupdate;k++)
298       {
299          celt_norm_t *tmp_ny;
300          int *tmp_iny;
301
302          tmp_ny = ny[k];
303          ny[k] = y[k];
304          y[k] = tmp_ny;
305          tmp_iny = iny[k];
306          iny[k] = iy[k];
307          iy[k] = tmp_iny;
308       }
309       pulsesLeft -= pulsesAtOnce;
310    }
311    
312 #if 0
313    if (0) {
314       float err=0;
315       for (i=0;i<N;i++)
316          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
317       /*if (N<=10)
318         printf ("%f %d %d\n", err, K, N);*/
319    }
320    /* Sanity checks, don't bother */
321    if (0) {
322       for (i=0;i<N;i++)
323          x[i] = p[i]+nbest[0]->gain*y[0][i];
324       float E=1e-15;
325       int ABS = 0;
326       for (i=0;i<N;i++)
327          ABS += abs(iy[0][i]);
328       /*if (K != ABS)
329          printf ("%d %d\n", K, ABS);*/
330       for (i=0;i<N;i++)
331          E += x[i]*x[i];
332       /*printf ("%f\n", E);*/
333       E = 1/sqrt(E);
334       for (i=0;i<N;i++)
335          x[i] *= E;
336    }
337 #endif
338    
339    encode_pulses(iy[0], N, K, enc);
340    
341    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
342       due to the recursive computation used in quantisation.
343       Not quite sure whether we need that or not */
344    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
345 }
346
347 /** Decode pulse vector and combine the result with the pitch vector to produce
348     the final normalised signal in the current band. */
349 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
350 {
351    VARDECL(int *iy);
352    ALLOC(iy, N, int);
353    decode_pulses(iy, N, K, dec);
354    mix_pitch_and_residual(iy, X, N, K, P, alpha);
355 }
356
357
358 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
359
360 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
361 {
362    int i,j;
363    int best=0;
364    float best_score=0;
365    float s = 1;
366    int sign;
367    float E;
368    float pred_gain;
369    int max_pos = N0-N/B;
370    if (max_pos > 32)
371       max_pos = 32;
372
373    for (i=0;i<max_pos*B;i+=B)
374    {
375       int j;
376       float xy=0, yy=0;
377       float score;
378       for (j=0;j<N;j++)
379       {
380          xy += 1.f*x[j]*Y[i+N-j-1];
381          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
382       }
383       score = xy*xy/(.001+yy);
384       if (score > best_score)
385       {
386          best_score = score;
387          best = i;
388          if (xy>0)
389             s = 1;
390          else
391             s = -1;
392       }
393    }
394    if (s<0)
395       sign = 1;
396    else
397       sign = 0;
398    /*printf ("%d %d ", sign, best);*/
399    ec_enc_uint(enc,sign,2);
400    ec_enc_uint(enc,best/B,max_pos);
401    /*printf ("%d %f\n", best, best_score);*/
402    
403    if (K>10)
404       pred_gain = pg[10];
405    else
406       pred_gain = pg[K];
407    E = 1e-10;
408    for (j=0;j<N;j++)
409    {
410       P[j] = s*Y[best+N-j-1];
411       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
412    }
413    E = pred_gain/sqrt(E);
414    for (j=0;j<N;j++)
415       P[j] *= E;
416    if (K>0)
417    {
418       for (j=0;j<N;j++)
419          x[j] -= P[j];
420    } else {
421       for (j=0;j<N;j++)
422          x[j] = P[j];
423    }
424    /*printf ("quant ");*/
425    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
426
427 }
428
429 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
430 {
431    int j;
432    int sign;
433    float s;
434    int best;
435    float E;
436    float pred_gain;
437    int max_pos = N0-N/B;
438    if (max_pos > 32)
439       max_pos = 32;
440    
441    sign = ec_dec_uint(dec, 2);
442    if (sign == 0)
443       s = 1;
444    else
445       s = -1;
446    
447    best = B*ec_dec_uint(dec, max_pos);
448    /*printf ("%d %d ", sign, best);*/
449
450    if (K>10)
451       pred_gain = pg[10];
452    else
453       pred_gain = pg[K];
454    E = 1e-10;
455    for (j=0;j<N;j++)
456    {
457       P[j] = s*Y[best+N-j-1];
458       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
459    }
460    E = pred_gain/sqrt(E);
461    for (j=0;j<N;j++)
462       P[j] *= E;
463    if (K==0)
464    {
465       for (j=0;j<N;j++)
466          x[j] = P[j];
467    }
468 }
469
470 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
471 {
472    int i, j;
473    float E;
474    
475    E = 1e-10;
476    if (N0 >= Nmax/2)
477    {
478       for (i=0;i<B;i++)
479       {
480          for (j=0;j<N/B;j++)
481          {
482             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
483             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
484          }
485       }
486    } else {
487       for (j=0;j<N;j++)
488       {
489          P[j] = Y[j];
490          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
491       }
492    }
493    E = 1.f/sqrt(E);
494    for (j=0;j<N;j++)
495       P[j] *= E;
496    for (j=0;j<N;j++)
497       x[j] = P[j];
498 }
499