fixed-point: Added a ROUND() operator, no real change to the code
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72 #ifdef FIXED_POINT
73    int yshift = 14-EC_ILOG(K);
74 #endif
75    ALLOC(y, N, celt_norm_t);
76
77    /*for (i=0;i<N;i++)
78    printf ("%d ", iy[i]);*/
79    Rpp = 0;
80    for (i=0;i<N;i++)
81       Rpp = MAC16_16(Rpp,P[i],P[i]);
82
83    Ryp = 0;
84    for (i=0;i<N;i++)
85       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
86
87    /* Remove part of the pitch component to compute the real residual from
88       the encoded (int) one */
89    for (i=0;i<N;i++)
90       y[i] = SUB16(SHL16(iy[i],yshift),
91                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
92
93    /* Recompute after the projection (I think it's right) */
94    Ryp = 0;
95    for (i=0;i<N;i++)
96       Ryp = MAC16_16(Ryp,y[i],P[i]);
97
98    Ryy = 0;
99    for (i=0;i<N;i++)
100       Ryy = MAC16_16(Ryy, y[i],y[i]);
101
102    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
103    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
104
105    for (i=0;i<N;i++)
106       X[i] = P[i] + MULT16_32_Q14(y[i], g);
107 }
108
109 /** All the info necessary to keep track of a hypothesis during the search */
110 struct NBest {
111    float score;
112    float gain;
113    int sign;
114    int pos;
115    int orig;
116    float xy;
117    float yy;
118    float yp;
119 };
120
121 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
122 {
123    int L = 3;
124    VARDECL(float *x);
125    VARDECL(float *p);
126    VARDECL(float *_y);
127    VARDECL(float *_ny);
128    VARDECL(int *_iy);
129    VARDECL(int *_iny);
130    VARDECL(float **y);
131    VARDECL(float **ny);
132    VARDECL(int **iy);
133    VARDECL(int **iny);
134    int i, j, k, m;
135    int pulsesLeft;
136    VARDECL(float *xy);
137    VARDECL(float *yy);
138    VARDECL(float *yp);
139    VARDECL(struct NBest *_nbest);
140    VARDECL(struct NBest **nbest);
141    float Rpp=0, Rxp=0;
142    int maxL = 1;
143    float _alpha = Q15_ONE_1*alpha;
144
145    ALLOC(x, N, float);
146    ALLOC(p, N, float);
147    ALLOC(_y, L*N, float);
148    ALLOC(_ny, L*N, float);
149    ALLOC(_iy, L*N, int);
150    ALLOC(_iny, L*N, int);
151    ALLOC(y, L*N, float*);
152    ALLOC(ny, L*N, float*);
153    ALLOC(iy, L*N, int*);
154    ALLOC(iny, L*N, int*);
155    
156    ALLOC(xy, L, float);
157    ALLOC(yy, L, float);
158    ALLOC(yp, L, float);
159    ALLOC(_nbest, L, struct NBest);
160    ALLOC(nbest, L, struct NBest *);
161
162    for (j=0;j<N;j++)
163    {
164       x[j] = X[j]*NORM_SCALING_1;
165       p[j] = P[j]*NORM_SCALING_1;
166    }
167    
168    for (m=0;m<L;m++)
169       nbest[m] = &_nbest[m];
170    
171    for (m=0;m<L;m++)
172    {
173       ny[m] = &_ny[m*N];
174       iny[m] = &_iny[m*N];
175       y[m] = &_y[m*N];
176       iy[m] = &_iy[m*N];
177    }
178    
179    for (j=0;j<N;j++)
180    {
181       Rpp += p[j]*p[j];
182       Rxp += x[j]*p[j];
183    }
184    if (Rpp>1)
185       celt_fatal("Rpp > 1");
186
187    /* We only need to initialise the zero because the first iteration only uses that */
188    for (i=0;i<N;i++)
189       y[0][i] = 0;
190    for (i=0;i<N;i++)
191       iy[0][i] = 0;
192    xy[0] = yy[0] = yp[0] = 0;
193
194    pulsesLeft = K;
195    while (pulsesLeft > 0)
196    {
197       int pulsesAtOnce=1;
198       int Lupdate = L;
199       int L2 = L;
200       
201       /* Decide on complexity strategy */
202       pulsesAtOnce = pulsesLeft/N;
203       if (pulsesAtOnce<1)
204          pulsesAtOnce = 1;
205       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
206          Lupdate = 1;
207       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
208       L2 = Lupdate;
209       if (L2>maxL)
210       {
211          L2 = maxL;
212          maxL *= N;
213       }
214
215       for (m=0;m<Lupdate;m++)
216          nbest[m]->score = -1e10f;
217
218       for (m=0;m<L2;m++)
219       {
220          for (j=0;j<N;j++)
221          {
222             int sign;
223             /*if (x[j]>0) sign=1; else sign=-1;*/
224             for (sign=-1;sign<=1;sign+=2)
225             {
226                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
227                float tmp_xy, tmp_yy, tmp_yp;
228                float score;
229                float g;
230                float s = sign*pulsesAtOnce;
231                
232                /* All pulses at one location must have the same sign. */
233                if (iy[m][j]*sign < 0)
234                   continue;
235
236                /* Updating the sums of the new pulse(s) */
237                tmp_xy = xy[m] + s*x[j]               - _alpha*s*p[j]*Rxp;
238                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*_alpha*_alpha*p[j]*p[j]*Rpp - 2.f*_alpha*s*p[j]*yp[m] - 2.f*s*s*_alpha*p[j]*p[j];
239                tmp_yp = yp[m] + s*p[j]               *(1.f-_alpha*Rpp);
240                
241                /* Compute the gain such that ||p + g*y|| = 1 */
242                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
243                /* Knowing that gain, what the error: (x-g*y)^2 
244                   (result is negated and we discard x^2 because it's constant) */
245                score = 2.f*g*tmp_xy - g*g*tmp_yy;
246
247                if (score>nbest[Lupdate-1]->score)
248                {
249                   int k;
250                   int id = Lupdate-1;
251                   struct NBest *tmp_best;
252
253                   /* Save some pointers that would be deleted and use them for the current entry*/
254                   tmp_best = nbest[Lupdate-1];
255                   while (id > 0 && score > nbest[id-1]->score)
256                      id--;
257                
258                   for (k=Lupdate-1;k>id;k--)
259                      nbest[k] = nbest[k-1];
260
261                   nbest[id] = tmp_best;
262                   nbest[id]->score = score;
263                   nbest[id]->pos = j;
264                   nbest[id]->orig = m;
265                   nbest[id]->sign = sign;
266                   nbest[id]->gain = g;
267                   nbest[id]->xy = tmp_xy;
268                   nbest[id]->yy = tmp_yy;
269                   nbest[id]->yp = tmp_yp;
270                }
271             }
272          }
273
274       }
275       
276       if (!(nbest[0]->score > -1e10f))
277          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
278       /* Only now that we've made the final choice, update ny/iny and others */
279       for (k=0;k<Lupdate;k++)
280       {
281          int n;
282          int is;
283          float s;
284          is = nbest[k]->sign*pulsesAtOnce;
285          s = is;
286          for (n=0;n<N;n++)
287             ny[k][n] = y[nbest[k]->orig][n] - _alpha*s*p[nbest[k]->pos]*p[n];
288          ny[k][nbest[k]->pos] += s;
289
290          for (n=0;n<N;n++)
291             iny[k][n] = iy[nbest[k]->orig][n];
292          iny[k][nbest[k]->pos] += is;
293
294          xy[k] = nbest[k]->xy;
295          yy[k] = nbest[k]->yy;
296          yp[k] = nbest[k]->yp;
297       }
298       /* Swap ny/iny with y/iy */
299       for (k=0;k<Lupdate;k++)
300       {
301          float *tmp_ny;
302          int *tmp_iny;
303
304          tmp_ny = ny[k];
305          ny[k] = y[k];
306          y[k] = tmp_ny;
307          tmp_iny = iny[k];
308          iny[k] = iy[k];
309          iy[k] = tmp_iny;
310       }
311       pulsesLeft -= pulsesAtOnce;
312    }
313    
314 #if 0
315    if (0) {
316       float err=0;
317       for (i=0;i<N;i++)
318          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
319       /*if (N<=10)
320         printf ("%f %d %d\n", err, K, N);*/
321    }
322    /* Sanity checks, don't bother */
323    if (0) {
324       for (i=0;i<N;i++)
325          x[i] = p[i]+nbest[0]->gain*y[0][i];
326       float E=1e-15;
327       int ABS = 0;
328       for (i=0;i<N;i++)
329          ABS += abs(iy[0][i]);
330       /*if (K != ABS)
331          printf ("%d %d\n", K, ABS);*/
332       for (i=0;i<N;i++)
333          E += x[i]*x[i];
334       /*printf ("%f\n", E);*/
335       E = 1/sqrt(E);
336       for (i=0;i<N;i++)
337          x[i] *= E;
338    }
339 #endif
340    
341    encode_pulses(iy[0], N, K, enc);
342    
343    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
344       due to the recursive computation used in quantisation.
345       Not quite sure whether we need that or not */
346    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
347 }
348
349 /** Decode pulse vector and combine the result with the pitch vector to produce
350     the final normalised signal in the current band. */
351 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
352 {
353    VARDECL(int *iy);
354    ALLOC(iy, N, int);
355    decode_pulses(iy, N, K, dec);
356    mix_pitch_and_residual(iy, X, N, K, P, alpha);
357 }
358
359
360 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
361
362 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
363 {
364    int i,j;
365    int best=0;
366    float best_score=0;
367    float s = 1;
368    int sign;
369    float E;
370    float pred_gain;
371    int max_pos = N0-N/B;
372    if (max_pos > 32)
373       max_pos = 32;
374
375    for (i=0;i<max_pos*B;i+=B)
376    {
377       int j;
378       float xy=0, yy=0;
379       float score;
380       for (j=0;j<N;j++)
381       {
382          xy += 1.f*x[j]*Y[i+N-j-1];
383          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
384       }
385       score = xy*xy/(.001+yy);
386       if (score > best_score)
387       {
388          best_score = score;
389          best = i;
390          if (xy>0)
391             s = 1;
392          else
393             s = -1;
394       }
395    }
396    if (s<0)
397       sign = 1;
398    else
399       sign = 0;
400    /*printf ("%d %d ", sign, best);*/
401    ec_enc_uint(enc,sign,2);
402    ec_enc_uint(enc,best/B,max_pos);
403    /*printf ("%d %f\n", best, best_score);*/
404    
405    if (K>10)
406       pred_gain = pg[10];
407    else
408       pred_gain = pg[K];
409    E = 1e-10;
410    for (j=0;j<N;j++)
411    {
412       P[j] = s*Y[best+N-j-1];
413       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
414    }
415    E = pred_gain/sqrt(E);
416    for (j=0;j<N;j++)
417       P[j] *= E;
418    if (K>0)
419    {
420       for (j=0;j<N;j++)
421          x[j] -= P[j];
422    } else {
423       for (j=0;j<N;j++)
424          x[j] = P[j];
425    }
426    /*printf ("quant ");*/
427    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
428
429 }
430
431 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
432 {
433    int j;
434    int sign;
435    float s;
436    int best;
437    float E;
438    float pred_gain;
439    int max_pos = N0-N/B;
440    if (max_pos > 32)
441       max_pos = 32;
442    
443    sign = ec_dec_uint(dec, 2);
444    if (sign == 0)
445       s = 1;
446    else
447       s = -1;
448    
449    best = B*ec_dec_uint(dec, max_pos);
450    /*printf ("%d %d ", sign, best);*/
451
452    if (K>10)
453       pred_gain = pg[10];
454    else
455       pred_gain = pg[K];
456    E = 1e-10;
457    for (j=0;j<N;j++)
458    {
459       P[j] = s*Y[best+N-j-1];
460       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
461    }
462    E = pred_gain/sqrt(E);
463    for (j=0;j<N;j++)
464       P[j] *= E;
465    if (K==0)
466    {
467       for (j=0;j<N;j++)
468          x[j] = P[j];
469    }
470 }
471
472 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
473 {
474    int i, j;
475    float E;
476    
477    E = 1e-10;
478    if (N0 >= Nmax/2)
479    {
480       for (i=0;i<B;i++)
481       {
482          for (j=0;j<N/B;j++)
483          {
484             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
485             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
486          }
487       }
488    } else {
489       for (j=0;j<N;j++)
490       {
491          P[j] = Y[j];
492          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
493       }
494    }
495    E = 1.f/sqrt(E);
496    for (j=0;j<N;j++)
497       P[j] *= E;
498    for (j=0;j<N;j++)
499       x[j] = P[j];
500 }
501