fixed-point: Using a NORM_SCALING of 16384, sig_norm_t is still a float though.
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41
42 /* Enable this or define your own implementation if you want to speed up the
43    VQ search (used in inner loop only) */
44 #if 0
45 #include <xmmintrin.h>
46 static inline float approx_sqrt(float x)
47 {
48    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
49    return x;
50 }
51 static inline float approx_inv(float x)
52 {
53    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
54    return x;
55 }
56 #else
57 #define approx_sqrt(x) (sqrt(x))
58 #define approx_inv(x) (1.f/(x))
59 #endif
60
61 /** All the info necessary to keep track of a hypothesis during the search */
62 struct NBest {
63    float score;
64    float gain;
65    int sign;
66    int pos;
67    int orig;
68    float xy;
69    float yy;
70    float yp;
71 };
72
73 void alg_quant(celt_norm_t *x, float *W, int N, int K, celt_norm_t *p, float alpha, ec_enc *enc)
74 {
75    int L = 3;
76    VARDECL(float *_y);
77    VARDECL(float *_ny);
78    VARDECL(int *_iy);
79    VARDECL(int *_iny);
80    VARDECL(float **y);
81    VARDECL(float **ny);
82    VARDECL(int **iy);
83    VARDECL(int **iny);
84    int i, j, k, m;
85    int pulsesLeft;
86    VARDECL(float *xy);
87    VARDECL(float *yy);
88    VARDECL(float *yp);
89    VARDECL(struct NBest *_nbest);
90    VARDECL(struct NBest **nbest);
91    float Rpp=0, Rxp=0;
92    int maxL = 1;
93    
94    ALLOC(_y, L*N, float);
95    ALLOC(_ny, L*N, float);
96    ALLOC(_iy, L*N, int);
97    ALLOC(_iny, L*N, int);
98    ALLOC(y, L*N, float*);
99    ALLOC(ny, L*N, float*);
100    ALLOC(iy, L*N, int*);
101    ALLOC(iny, L*N, int*);
102    
103    ALLOC(xy, L, float);
104    ALLOC(yy, L, float);
105    ALLOC(yp, L, float);
106    ALLOC(_nbest, L, struct NBest);
107    ALLOC(nbest, L, struct NBest *);
108
109    for (j=0;j<N;j++)
110    {
111       x[j] *= NORM_SCALING_1;
112       p[j] *= NORM_SCALING_1;
113    }
114    
115    for (m=0;m<L;m++)
116       nbest[m] = &_nbest[m];
117    
118    for (m=0;m<L;m++)
119    {
120       ny[m] = &_ny[m*N];
121       iny[m] = &_iny[m*N];
122       y[m] = &_y[m*N];
123       iy[m] = &_iy[m*N];
124    }
125    
126    for (j=0;j<N;j++)
127    {
128       Rpp += p[j]*p[j];
129       Rxp += x[j]*p[j];
130    }
131    
132    /* We only need to initialise the zero because the first iteration only uses that */
133    for (i=0;i<N;i++)
134       y[0][i] = 0;
135    for (i=0;i<N;i++)
136       iy[0][i] = 0;
137    xy[0] = yy[0] = yp[0] = 0;
138
139    pulsesLeft = K;
140    while (pulsesLeft > 0)
141    {
142       int pulsesAtOnce=1;
143       int Lupdate = L;
144       int L2 = L;
145       
146       /* Decide on complexity strategy */
147       pulsesAtOnce = pulsesLeft/N;
148       if (pulsesAtOnce<1)
149          pulsesAtOnce = 1;
150       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
151          Lupdate = 1;
152       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
153       L2 = Lupdate;
154       if (L2>maxL)
155       {
156          L2 = maxL;
157          maxL *= N;
158       }
159
160       for (m=0;m<Lupdate;m++)
161          nbest[m]->score = -1e10f;
162
163       for (m=0;m<L2;m++)
164       {
165          for (j=0;j<N;j++)
166          {
167             int sign;
168             /*if (x[j]>0) sign=1; else sign=-1;*/
169             for (sign=-1;sign<=1;sign+=2)
170             {
171                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
172                float tmp_xy, tmp_yy, tmp_yp;
173                float score;
174                float g;
175                float s = sign*pulsesAtOnce;
176                
177                /* All pulses at one location must have the same sign. */
178                if (iy[m][j]*sign < 0)
179                   continue;
180
181                /* Updating the sums of the new pulse(s) */
182                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
183                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2.f*alpha*s*p[j]*yp[m] - 2.f*s*s*alpha*p[j]*p[j];
184                tmp_yp = yp[m] + s*p[j]               *(1.f-alpha*Rpp);
185                
186                /* Compute the gain such that ||p + g*y|| = 1 */
187                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
188                /* Knowing that gain, what the error: (x-g*y)^2 
189                   (result is negated and we discard x^2 because it's constant) */
190                score = 2.f*g*tmp_xy - g*g*tmp_yy;
191
192                if (score>nbest[Lupdate-1]->score)
193                {
194                   int k;
195                   int id = Lupdate-1;
196                   struct NBest *tmp_best;
197
198                   /* Save some pointers that would be deleted and use them for the current entry*/
199                   tmp_best = nbest[Lupdate-1];
200                   while (id > 0 && score > nbest[id-1]->score)
201                      id--;
202                
203                   for (k=Lupdate-1;k>id;k--)
204                      nbest[k] = nbest[k-1];
205
206                   nbest[id] = tmp_best;
207                   nbest[id]->score = score;
208                   nbest[id]->pos = j;
209                   nbest[id]->orig = m;
210                   nbest[id]->sign = sign;
211                   nbest[id]->gain = g;
212                   nbest[id]->xy = tmp_xy;
213                   nbest[id]->yy = tmp_yy;
214                   nbest[id]->yp = tmp_yp;
215                }
216             }
217          }
218
219       }
220       /* Only now that we've made the final choice, update ny/iny and others */
221       for (k=0;k<Lupdate;k++)
222       {
223          int n;
224          int is;
225          float s;
226          is = nbest[k]->sign*pulsesAtOnce;
227          s = is;
228          for (n=0;n<N;n++)
229             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
230          ny[k][nbest[k]->pos] += s;
231
232          for (n=0;n<N;n++)
233             iny[k][n] = iy[nbest[k]->orig][n];
234          iny[k][nbest[k]->pos] += is;
235
236          xy[k] = nbest[k]->xy;
237          yy[k] = nbest[k]->yy;
238          yp[k] = nbest[k]->yp;
239       }
240       /* Swap ny/iny with y/iy */
241       for (k=0;k<Lupdate;k++)
242       {
243          float *tmp_ny;
244          int *tmp_iny;
245
246          tmp_ny = ny[k];
247          ny[k] = y[k];
248          y[k] = tmp_ny;
249          tmp_iny = iny[k];
250          iny[k] = iy[k];
251          iy[k] = tmp_iny;
252       }
253       pulsesLeft -= pulsesAtOnce;
254    }
255    
256    if (0) {
257       float err=0;
258       for (i=0;i<N;i++)
259          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
260       /*if (N<=10)
261         printf ("%f %d %d\n", err, K, N);*/
262    }
263    for (i=0;i<N;i++)
264       x[i] = p[i]+nbest[0]->gain*y[0][i];
265    /* Sanity checks, don't bother */
266    if (0) {
267       float E=1e-15;
268       int ABS = 0;
269       for (i=0;i<N;i++)
270          ABS += abs(iy[0][i]);
271       /*if (K != ABS)
272          printf ("%d %d\n", K, ABS);*/
273       for (i=0;i<N;i++)
274          E += x[i]*x[i];
275       /*printf ("%f\n", E);*/
276       E = 1/sqrt(E);
277       for (i=0;i<N;i++)
278          x[i] *= E;
279    }
280    
281    encode_pulses(iy[0], N, K, enc);
282    
283    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
284       due to the recursive computation used in quantisation.
285       Not quite sure whether we need that or not */
286    if (1) {
287       float Ryp=0;
288       float Ryy=0;
289       float g=0;
290       
291       for (i=0;i<N;i++)
292          Ryp += iy[0][i]*p[i];
293       
294       for (i=0;i<N;i++)
295          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
296       
297       Ryp = 0;
298       for (i=0;i<N;i++)
299          Ryp += y[0][i]*p[i];
300       
301       for (i=0;i<N;i++)
302          Ryy += y[0][i]*y[0][i];
303       
304       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
305         
306       for (i=0;i<N;i++)
307          x[i] = p[i] + g*y[0][i];
308       
309    }
310    for (j=0;j<N;j++)
311    {
312       x[j] *= NORM_SCALING;
313       p[j] *= NORM_SCALING;
314    }
315
316 }
317
318 /** Decode pulse vector and combine the result with the pitch vector to produce
319     the final normalised signal in the current band. */
320 void alg_unquant(celt_norm_t *x, int N, int K, celt_norm_t *p, float alpha, ec_dec *dec)
321 {
322    int i;
323    float Rpp=0, Ryp=0, Ryy=0;
324    float g;
325    VARDECL(int *iy);
326    VARDECL(float *y);
327    
328    ALLOC(iy, N, int);
329    ALLOC(y, N, float);
330
331    decode_pulses(iy, N, K, dec);
332    for (i=0;i<N;i++)
333    {
334       x[i] *= NORM_SCALING_1;
335       p[i] *= NORM_SCALING_1;
336    }
337
338    /*for (i=0;i<N;i++)
339       printf ("%d ", iy[i]);*/
340    for (i=0;i<N;i++)
341       Rpp += p[i]*p[i];
342
343    for (i=0;i<N;i++)
344       Ryp += iy[i]*p[i];
345
346    for (i=0;i<N;i++)
347       y[i] = iy[i] - alpha*Ryp*p[i];
348
349    /* Recompute after the projection (I think it's right) */
350    Ryp = 0;
351    for (i=0;i<N;i++)
352       Ryp += y[i]*p[i];
353
354    for (i=0;i<N;i++)
355       Ryy += y[i]*y[i];
356
357    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
358
359    for (i=0;i<N;i++)
360       x[i] = p[i] + g*y[i];
361    for (i=0;i<N;i++)
362    {
363       x[i] *= NORM_SCALING;
364       p[i] *= NORM_SCALING;
365    }
366
367 }
368
369
370 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
371
372 void intra_prediction(celt_norm_t *x, float *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
373 {
374    int i,j;
375    int best=0;
376    float best_score=0;
377    float s = 1;
378    int sign;
379    float E;
380    float pred_gain;
381    int max_pos = N0-N/B;
382    if (max_pos > 32)
383       max_pos = 32;
384
385    for (i=0;i<max_pos*B;i+=B)
386    {
387       int j;
388       float xy=0, yy=0;
389       float score;
390       for (j=0;j<N;j++)
391       {
392          xy += 1.f*x[j]*Y[i+N-j-1];
393          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
394       }
395       score = xy*xy/(.001+yy);
396       if (score > best_score)
397       {
398          best_score = score;
399          best = i;
400          if (xy>0)
401             s = 1;
402          else
403             s = -1;
404       }
405    }
406    if (s<0)
407       sign = 1;
408    else
409       sign = 0;
410    /*printf ("%d %d ", sign, best);*/
411    ec_enc_uint(enc,sign,2);
412    ec_enc_uint(enc,best/B,max_pos);
413    /*printf ("%d %f\n", best, best_score);*/
414    
415    if (K>10)
416       pred_gain = pg[10];
417    else
418       pred_gain = pg[K];
419    E = 1e-10;
420    for (j=0;j<N;j++)
421    {
422       P[j] = s*Y[best+N-j-1];
423       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
424    }
425    E = pred_gain/sqrt(E);
426    for (j=0;j<N;j++)
427       P[j] *= E;
428    if (K>0)
429    {
430       for (j=0;j<N;j++)
431          x[j] -= P[j];
432    } else {
433       for (j=0;j<N;j++)
434          x[j] = P[j];
435    }
436    /*printf ("quant ");*/
437    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
438
439 }
440
441 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
442 {
443    int j;
444    int sign;
445    float s;
446    int best;
447    float E;
448    float pred_gain;
449    int max_pos = N0-N/B;
450    if (max_pos > 32)
451       max_pos = 32;
452    
453    sign = ec_dec_uint(dec, 2);
454    if (sign == 0)
455       s = 1;
456    else
457       s = -1;
458    
459    best = B*ec_dec_uint(dec, max_pos);
460    /*printf ("%d %d ", sign, best);*/
461
462    if (K>10)
463       pred_gain = pg[10];
464    else
465       pred_gain = pg[K];
466    E = 1e-10;
467    for (j=0;j<N;j++)
468    {
469       P[j] = s*Y[best+N-j-1];
470       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
471    }
472    E = pred_gain/sqrt(E);
473    for (j=0;j<N;j++)
474       P[j] *= E;
475    if (K==0)
476    {
477       for (j=0;j<N;j++)
478          x[j] = P[j];
479    }
480 }
481
482 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
483 {
484    int i, j;
485    float E;
486    
487    E = 1e-10;
488    if (N0 >= Nmax/2)
489    {
490       for (i=0;i<B;i++)
491       {
492          for (j=0;j<N/B;j++)
493          {
494             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
495             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
496          }
497       }
498    } else {
499       for (j=0;j<N;j++)
500       {
501          P[j] = Y[j];
502          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
503       }
504    }
505    E = 1.f/sqrt(E);
506    for (j=0;j<N;j++)
507       P[j] *= E;
508    for (j=0;j<N;j++)
509       x[j] = P[j];
510 }
511