fixed-point: celt_norm_t now a 16-bit value.
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41
42 /* Enable this or define your own implementation if you want to speed up the
43    VQ search (used in inner loop only) */
44 #if 0
45 #include <xmmintrin.h>
46 static inline float approx_sqrt(float x)
47 {
48    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
49    return x;
50 }
51 static inline float approx_inv(float x)
52 {
53    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
54    return x;
55 }
56 #else
57 #define approx_sqrt(x) (sqrt(x))
58 #define approx_inv(x) (1.f/(x))
59 #endif
60
61 /** All the info necessary to keep track of a hypothesis during the search */
62 struct NBest {
63    float score;
64    float gain;
65    int sign;
66    int pos;
67    int orig;
68    float xy;
69    float yy;
70    float yp;
71 };
72
73 void alg_quant(celt_norm_t *X, float *W, int N, int K, celt_norm_t *P, float alpha, ec_enc *enc)
74 {
75    int L = 3;
76    VARDECL(float *x);
77    VARDECL(float *p);
78    VARDECL(float *_y);
79    VARDECL(float *_ny);
80    VARDECL(int *_iy);
81    VARDECL(int *_iny);
82    VARDECL(float **y);
83    VARDECL(float **ny);
84    VARDECL(int **iy);
85    VARDECL(int **iny);
86    int i, j, k, m;
87    int pulsesLeft;
88    VARDECL(float *xy);
89    VARDECL(float *yy);
90    VARDECL(float *yp);
91    VARDECL(struct NBest *_nbest);
92    VARDECL(struct NBest **nbest);
93    float Rpp=0, Rxp=0;
94    int maxL = 1;
95    
96    ALLOC(x, N, float);
97    ALLOC(p, N, float);
98    ALLOC(_y, L*N, float);
99    ALLOC(_ny, L*N, float);
100    ALLOC(_iy, L*N, int);
101    ALLOC(_iny, L*N, int);
102    ALLOC(y, L*N, float*);
103    ALLOC(ny, L*N, float*);
104    ALLOC(iy, L*N, int*);
105    ALLOC(iny, L*N, int*);
106    
107    ALLOC(xy, L, float);
108    ALLOC(yy, L, float);
109    ALLOC(yp, L, float);
110    ALLOC(_nbest, L, struct NBest);
111    ALLOC(nbest, L, struct NBest *);
112
113    for (j=0;j<N;j++)
114    {
115       x[j] = X[j]*NORM_SCALING_1;
116       p[j] = P[j]*NORM_SCALING_1;
117    }
118    
119    for (m=0;m<L;m++)
120       nbest[m] = &_nbest[m];
121    
122    for (m=0;m<L;m++)
123    {
124       ny[m] = &_ny[m*N];
125       iny[m] = &_iny[m*N];
126       y[m] = &_y[m*N];
127       iy[m] = &_iy[m*N];
128    }
129    
130    for (j=0;j<N;j++)
131    {
132       Rpp += p[j]*p[j];
133       Rxp += x[j]*p[j];
134    }
135    
136    /* We only need to initialise the zero because the first iteration only uses that */
137    for (i=0;i<N;i++)
138       y[0][i] = 0;
139    for (i=0;i<N;i++)
140       iy[0][i] = 0;
141    xy[0] = yy[0] = yp[0] = 0;
142
143    pulsesLeft = K;
144    while (pulsesLeft > 0)
145    {
146       int pulsesAtOnce=1;
147       int Lupdate = L;
148       int L2 = L;
149       
150       /* Decide on complexity strategy */
151       pulsesAtOnce = pulsesLeft/N;
152       if (pulsesAtOnce<1)
153          pulsesAtOnce = 1;
154       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
155          Lupdate = 1;
156       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
157       L2 = Lupdate;
158       if (L2>maxL)
159       {
160          L2 = maxL;
161          maxL *= N;
162       }
163
164       for (m=0;m<Lupdate;m++)
165          nbest[m]->score = -1e10f;
166
167       for (m=0;m<L2;m++)
168       {
169          for (j=0;j<N;j++)
170          {
171             int sign;
172             /*if (x[j]>0) sign=1; else sign=-1;*/
173             for (sign=-1;sign<=1;sign+=2)
174             {
175                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
176                float tmp_xy, tmp_yy, tmp_yp;
177                float score;
178                float g;
179                float s = sign*pulsesAtOnce;
180                
181                /* All pulses at one location must have the same sign. */
182                if (iy[m][j]*sign < 0)
183                   continue;
184
185                /* Updating the sums of the new pulse(s) */
186                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
187                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2.f*alpha*s*p[j]*yp[m] - 2.f*s*s*alpha*p[j]*p[j];
188                tmp_yp = yp[m] + s*p[j]               *(1.f-alpha*Rpp);
189                
190                /* Compute the gain such that ||p + g*y|| = 1 */
191                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
192                /* Knowing that gain, what the error: (x-g*y)^2 
193                   (result is negated and we discard x^2 because it's constant) */
194                score = 2.f*g*tmp_xy - g*g*tmp_yy;
195
196                if (score>nbest[Lupdate-1]->score)
197                {
198                   int k;
199                   int id = Lupdate-1;
200                   struct NBest *tmp_best;
201
202                   /* Save some pointers that would be deleted and use them for the current entry*/
203                   tmp_best = nbest[Lupdate-1];
204                   while (id > 0 && score > nbest[id-1]->score)
205                      id--;
206                
207                   for (k=Lupdate-1;k>id;k--)
208                      nbest[k] = nbest[k-1];
209
210                   nbest[id] = tmp_best;
211                   nbest[id]->score = score;
212                   nbest[id]->pos = j;
213                   nbest[id]->orig = m;
214                   nbest[id]->sign = sign;
215                   nbest[id]->gain = g;
216                   nbest[id]->xy = tmp_xy;
217                   nbest[id]->yy = tmp_yy;
218                   nbest[id]->yp = tmp_yp;
219                }
220             }
221          }
222
223       }
224       /* Only now that we've made the final choice, update ny/iny and others */
225       for (k=0;k<Lupdate;k++)
226       {
227          int n;
228          int is;
229          float s;
230          is = nbest[k]->sign*pulsesAtOnce;
231          s = is;
232          for (n=0;n<N;n++)
233             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
234          ny[k][nbest[k]->pos] += s;
235
236          for (n=0;n<N;n++)
237             iny[k][n] = iy[nbest[k]->orig][n];
238          iny[k][nbest[k]->pos] += is;
239
240          xy[k] = nbest[k]->xy;
241          yy[k] = nbest[k]->yy;
242          yp[k] = nbest[k]->yp;
243       }
244       /* Swap ny/iny with y/iy */
245       for (k=0;k<Lupdate;k++)
246       {
247          float *tmp_ny;
248          int *tmp_iny;
249
250          tmp_ny = ny[k];
251          ny[k] = y[k];
252          y[k] = tmp_ny;
253          tmp_iny = iny[k];
254          iny[k] = iy[k];
255          iy[k] = tmp_iny;
256       }
257       pulsesLeft -= pulsesAtOnce;
258    }
259    
260    if (0) {
261       float err=0;
262       for (i=0;i<N;i++)
263          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
264       /*if (N<=10)
265         printf ("%f %d %d\n", err, K, N);*/
266    }
267    for (i=0;i<N;i++)
268       x[i] = p[i]+nbest[0]->gain*y[0][i];
269    /* Sanity checks, don't bother */
270    if (0) {
271       float E=1e-15;
272       int ABS = 0;
273       for (i=0;i<N;i++)
274          ABS += abs(iy[0][i]);
275       /*if (K != ABS)
276          printf ("%d %d\n", K, ABS);*/
277       for (i=0;i<N;i++)
278          E += x[i]*x[i];
279       /*printf ("%f\n", E);*/
280       E = 1/sqrt(E);
281       for (i=0;i<N;i++)
282          x[i] *= E;
283    }
284    
285    encode_pulses(iy[0], N, K, enc);
286    
287    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
288       due to the recursive computation used in quantisation.
289       Not quite sure whether we need that or not */
290    if (1) {
291       float Ryp=0;
292       float Ryy=0;
293       float g=0;
294       
295       for (i=0;i<N;i++)
296          Ryp += iy[0][i]*p[i];
297       
298       for (i=0;i<N;i++)
299          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
300       
301       Ryp = 0;
302       for (i=0;i<N;i++)
303          Ryp += y[0][i]*p[i];
304       
305       for (i=0;i<N;i++)
306          Ryy += y[0][i]*y[0][i];
307       
308       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
309         
310       for (i=0;i<N;i++)
311          x[i] = p[i] + g*y[0][i];
312       
313    }
314    for (j=0;j<N;j++)
315    {
316       X[j] = x[j] * NORM_SCALING;
317       P[j] = p[j] * NORM_SCALING;
318    }
319
320 }
321
322 /** Decode pulse vector and combine the result with the pitch vector to produce
323     the final normalised signal in the current band. */
324 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, float alpha, ec_dec *dec)
325 {
326    int i;
327    float Rpp=0, Ryp=0, Ryy=0;
328    float g;
329    VARDECL(int *iy);
330    VARDECL(float *y);
331    VARDECL(float *x);
332    VARDECL(float *p);
333    
334    ALLOC(iy, N, int);
335    ALLOC(y, N, float);
336    ALLOC(x, N, float);
337    ALLOC(p, N, float);
338
339    decode_pulses(iy, N, K, dec);
340    for (i=0;i<N;i++)
341    {
342       x[i] = X[i]*NORM_SCALING_1;
343       p[i] = P[i]*NORM_SCALING_1;
344    }
345
346    /*for (i=0;i<N;i++)
347       printf ("%d ", iy[i]);*/
348    for (i=0;i<N;i++)
349       Rpp += p[i]*p[i];
350
351    for (i=0;i<N;i++)
352       Ryp += iy[i]*p[i];
353
354    for (i=0;i<N;i++)
355       y[i] = iy[i] - alpha*Ryp*p[i];
356
357    /* Recompute after the projection (I think it's right) */
358    Ryp = 0;
359    for (i=0;i<N;i++)
360       Ryp += y[i]*p[i];
361
362    for (i=0;i<N;i++)
363       Ryy += y[i]*y[i];
364
365    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
366
367    for (i=0;i<N;i++)
368       x[i] = p[i] + g*y[i];
369    for (i=0;i<N;i++)
370    {
371       X[i] = x[i] * NORM_SCALING;
372       P[i] = p[i] * NORM_SCALING;
373    }
374
375 }
376
377
378 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
379
380 void intra_prediction(celt_norm_t *x, float *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
381 {
382    int i,j;
383    int best=0;
384    float best_score=0;
385    float s = 1;
386    int sign;
387    float E;
388    float pred_gain;
389    int max_pos = N0-N/B;
390    if (max_pos > 32)
391       max_pos = 32;
392
393    for (i=0;i<max_pos*B;i+=B)
394    {
395       int j;
396       float xy=0, yy=0;
397       float score;
398       for (j=0;j<N;j++)
399       {
400          xy += 1.f*x[j]*Y[i+N-j-1];
401          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
402       }
403       score = xy*xy/(.001+yy);
404       if (score > best_score)
405       {
406          best_score = score;
407          best = i;
408          if (xy>0)
409             s = 1;
410          else
411             s = -1;
412       }
413    }
414    if (s<0)
415       sign = 1;
416    else
417       sign = 0;
418    /*printf ("%d %d ", sign, best);*/
419    ec_enc_uint(enc,sign,2);
420    ec_enc_uint(enc,best/B,max_pos);
421    /*printf ("%d %f\n", best, best_score);*/
422    
423    if (K>10)
424       pred_gain = pg[10];
425    else
426       pred_gain = pg[K];
427    E = 1e-10;
428    for (j=0;j<N;j++)
429    {
430       P[j] = s*Y[best+N-j-1];
431       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
432    }
433    E = pred_gain/sqrt(E);
434    for (j=0;j<N;j++)
435       P[j] *= E;
436    if (K>0)
437    {
438       for (j=0;j<N;j++)
439          x[j] -= P[j];
440    } else {
441       for (j=0;j<N;j++)
442          x[j] = P[j];
443    }
444    /*printf ("quant ");*/
445    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
446
447 }
448
449 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
450 {
451    int j;
452    int sign;
453    float s;
454    int best;
455    float E;
456    float pred_gain;
457    int max_pos = N0-N/B;
458    if (max_pos > 32)
459       max_pos = 32;
460    
461    sign = ec_dec_uint(dec, 2);
462    if (sign == 0)
463       s = 1;
464    else
465       s = -1;
466    
467    best = B*ec_dec_uint(dec, max_pos);
468    /*printf ("%d %d ", sign, best);*/
469
470    if (K>10)
471       pred_gain = pg[10];
472    else
473       pred_gain = pg[K];
474    E = 1e-10;
475    for (j=0;j<N;j++)
476    {
477       P[j] = s*Y[best+N-j-1];
478       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
479    }
480    E = pred_gain/sqrt(E);
481    for (j=0;j<N;j++)
482       P[j] *= E;
483    if (K==0)
484    {
485       for (j=0;j<N;j++)
486          x[j] = P[j];
487    }
488 }
489
490 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
491 {
492    int i, j;
493    float E;
494    
495    E = 1e-10;
496    if (N0 >= Nmax/2)
497    {
498       for (i=0;i<B;i++)
499       {
500          for (j=0;j<N/B;j++)
501          {
502             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
503             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
504          }
505       }
506    } else {
507       for (j=0;j<N;j++)
508       {
509          P[j] = Y[j];
510          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
511       }
512    }
513    E = 1.f/sqrt(E);
514    for (j=0;j<N;j++)
515       P[j] *= E;
516    for (j=0;j<N;j++)
517       x[j] = P[j];
518 }
519