fixed-point: converted pitch_quant_bands() -- that one was an easy one-liner
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42
43 /* Enable this or define your own implementation if you want to speed up the
44    VQ search (used in inner loop only) */
45 #if 0
46 #include <xmmintrin.h>
47 static inline float approx_sqrt(float x)
48 {
49    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
50    return x;
51 }
52 static inline float approx_inv(float x)
53 {
54    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
55    return x;
56 }
57 #else
58 #define approx_sqrt(x) (sqrt(x))
59 #define approx_inv(x) (1.f/(x))
60 #endif
61
62 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
63    applies the compression in the pitch direction, computes the gain that will
64    give ||p+g*y||=1 and mixes the residual with the pitch. */
65 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, celt_norm_t *P, float alpha)
66 {
67    int i;
68    float Rpp=0, Ryp=0, Ryy=0;
69    float g;
70    VARDECL(float *y);
71    VARDECL(float *x);
72    VARDECL(float *p);
73    
74    ALLOC(y, N, float);
75    ALLOC(x, N, float);
76    ALLOC(p, N, float);
77
78    for (i=0;i<N;i++)
79       p[i] = P[i]*NORM_SCALING_1;
80
81    /*for (i=0;i<N;i++)
82    printf ("%d ", iy[i]);*/
83    for (i=0;i<N;i++)
84       Rpp += p[i]*p[i];
85
86    for (i=0;i<N;i++)
87       Ryp += iy[i]*p[i];
88
89    for (i=0;i<N;i++)
90       y[i] = iy[i] - alpha*Ryp*p[i];
91
92    /* Recompute after the projection (I think it's right) */
93    Ryp = 0;
94    for (i=0;i<N;i++)
95       Ryp += y[i]*p[i];
96
97    for (i=0;i<N;i++)
98       Ryy += y[i]*y[i];
99
100    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
101
102    for (i=0;i<N;i++)
103       X[i] = P[i] + NORM_SCALING*g*y[i];
104 }
105
106 /** All the info necessary to keep track of a hypothesis during the search */
107 struct NBest {
108    float score;
109    float gain;
110    int sign;
111    int pos;
112    int orig;
113    float xy;
114    float yy;
115    float yp;
116 };
117
118 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, float alpha, ec_enc *enc)
119 {
120    int L = 3;
121    VARDECL(float *x);
122    VARDECL(float *p);
123    VARDECL(float *_y);
124    VARDECL(float *_ny);
125    VARDECL(int *_iy);
126    VARDECL(int *_iny);
127    VARDECL(float **y);
128    VARDECL(float **ny);
129    VARDECL(int **iy);
130    VARDECL(int **iny);
131    int i, j, k, m;
132    int pulsesLeft;
133    VARDECL(float *xy);
134    VARDECL(float *yy);
135    VARDECL(float *yp);
136    VARDECL(struct NBest *_nbest);
137    VARDECL(struct NBest **nbest);
138    float Rpp=0, Rxp=0;
139    int maxL = 1;
140    
141    ALLOC(x, N, float);
142    ALLOC(p, N, float);
143    ALLOC(_y, L*N, float);
144    ALLOC(_ny, L*N, float);
145    ALLOC(_iy, L*N, int);
146    ALLOC(_iny, L*N, int);
147    ALLOC(y, L*N, float*);
148    ALLOC(ny, L*N, float*);
149    ALLOC(iy, L*N, int*);
150    ALLOC(iny, L*N, int*);
151    
152    ALLOC(xy, L, float);
153    ALLOC(yy, L, float);
154    ALLOC(yp, L, float);
155    ALLOC(_nbest, L, struct NBest);
156    ALLOC(nbest, L, struct NBest *);
157
158    for (j=0;j<N;j++)
159    {
160       x[j] = X[j]*NORM_SCALING_1;
161       p[j] = P[j]*NORM_SCALING_1;
162    }
163    
164    for (m=0;m<L;m++)
165       nbest[m] = &_nbest[m];
166    
167    for (m=0;m<L;m++)
168    {
169       ny[m] = &_ny[m*N];
170       iny[m] = &_iny[m*N];
171       y[m] = &_y[m*N];
172       iy[m] = &_iy[m*N];
173    }
174    
175    for (j=0;j<N;j++)
176    {
177       Rpp += p[j]*p[j];
178       Rxp += x[j]*p[j];
179    }
180    if (Rpp>1)
181       celt_fatal("Rpp > 1");
182
183    /* We only need to initialise the zero because the first iteration only uses that */
184    for (i=0;i<N;i++)
185       y[0][i] = 0;
186    for (i=0;i<N;i++)
187       iy[0][i] = 0;
188    xy[0] = yy[0] = yp[0] = 0;
189
190    pulsesLeft = K;
191    while (pulsesLeft > 0)
192    {
193       int pulsesAtOnce=1;
194       int Lupdate = L;
195       int L2 = L;
196       
197       /* Decide on complexity strategy */
198       pulsesAtOnce = pulsesLeft/N;
199       if (pulsesAtOnce<1)
200          pulsesAtOnce = 1;
201       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
202          Lupdate = 1;
203       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
204       L2 = Lupdate;
205       if (L2>maxL)
206       {
207          L2 = maxL;
208          maxL *= N;
209       }
210
211       for (m=0;m<Lupdate;m++)
212          nbest[m]->score = -1e10f;
213
214       for (m=0;m<L2;m++)
215       {
216          for (j=0;j<N;j++)
217          {
218             int sign;
219             /*if (x[j]>0) sign=1; else sign=-1;*/
220             for (sign=-1;sign<=1;sign+=2)
221             {
222                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
223                float tmp_xy, tmp_yy, tmp_yp;
224                float score;
225                float g;
226                float s = sign*pulsesAtOnce;
227                
228                /* All pulses at one location must have the same sign. */
229                if (iy[m][j]*sign < 0)
230                   continue;
231
232                /* Updating the sums of the new pulse(s) */
233                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
234                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2.f*alpha*s*p[j]*yp[m] - 2.f*s*s*alpha*p[j]*p[j];
235                tmp_yp = yp[m] + s*p[j]               *(1.f-alpha*Rpp);
236                
237                /* Compute the gain such that ||p + g*y|| = 1 */
238                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
239                /* Knowing that gain, what the error: (x-g*y)^2 
240                   (result is negated and we discard x^2 because it's constant) */
241                score = 2.f*g*tmp_xy - g*g*tmp_yy;
242
243                if (score>nbest[Lupdate-1]->score)
244                {
245                   int k;
246                   int id = Lupdate-1;
247                   struct NBest *tmp_best;
248
249                   /* Save some pointers that would be deleted and use them for the current entry*/
250                   tmp_best = nbest[Lupdate-1];
251                   while (id > 0 && score > nbest[id-1]->score)
252                      id--;
253                
254                   for (k=Lupdate-1;k>id;k--)
255                      nbest[k] = nbest[k-1];
256
257                   nbest[id] = tmp_best;
258                   nbest[id]->score = score;
259                   nbest[id]->pos = j;
260                   nbest[id]->orig = m;
261                   nbest[id]->sign = sign;
262                   nbest[id]->gain = g;
263                   nbest[id]->xy = tmp_xy;
264                   nbest[id]->yy = tmp_yy;
265                   nbest[id]->yp = tmp_yp;
266                }
267             }
268          }
269
270       }
271       
272       if (!(nbest[0]->score > -1e10f))
273          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
274       /* Only now that we've made the final choice, update ny/iny and others */
275       for (k=0;k<Lupdate;k++)
276       {
277          int n;
278          int is;
279          float s;
280          is = nbest[k]->sign*pulsesAtOnce;
281          s = is;
282          for (n=0;n<N;n++)
283             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
284          ny[k][nbest[k]->pos] += s;
285
286          for (n=0;n<N;n++)
287             iny[k][n] = iy[nbest[k]->orig][n];
288          iny[k][nbest[k]->pos] += is;
289
290          xy[k] = nbest[k]->xy;
291          yy[k] = nbest[k]->yy;
292          yp[k] = nbest[k]->yp;
293       }
294       /* Swap ny/iny with y/iy */
295       for (k=0;k<Lupdate;k++)
296       {
297          float *tmp_ny;
298          int *tmp_iny;
299
300          tmp_ny = ny[k];
301          ny[k] = y[k];
302          y[k] = tmp_ny;
303          tmp_iny = iny[k];
304          iny[k] = iy[k];
305          iy[k] = tmp_iny;
306       }
307       pulsesLeft -= pulsesAtOnce;
308    }
309    
310 #if 0
311    if (0) {
312       float err=0;
313       for (i=0;i<N;i++)
314          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
315       /*if (N<=10)
316         printf ("%f %d %d\n", err, K, N);*/
317    }
318    /* Sanity checks, don't bother */
319    if (0) {
320       for (i=0;i<N;i++)
321          x[i] = p[i]+nbest[0]->gain*y[0][i];
322       float E=1e-15;
323       int ABS = 0;
324       for (i=0;i<N;i++)
325          ABS += abs(iy[0][i]);
326       /*if (K != ABS)
327          printf ("%d %d\n", K, ABS);*/
328       for (i=0;i<N;i++)
329          E += x[i]*x[i];
330       /*printf ("%f\n", E);*/
331       E = 1/sqrt(E);
332       for (i=0;i<N;i++)
333          x[i] *= E;
334    }
335 #endif
336    
337    encode_pulses(iy[0], N, K, enc);
338    
339    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
340       due to the recursive computation used in quantisation.
341       Not quite sure whether we need that or not */
342    mix_pitch_and_residual(iy[0], X, N, P, alpha);
343 }
344
345 /** Decode pulse vector and combine the result with the pitch vector to produce
346     the final normalised signal in the current band. */
347 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, float alpha, ec_dec *dec)
348 {
349    VARDECL(int *iy);
350    ALLOC(iy, N, int);
351    decode_pulses(iy, N, K, dec);
352    mix_pitch_and_residual(iy, X, N, P, alpha);
353 }
354
355
356 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
357
358 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
359 {
360    int i,j;
361    int best=0;
362    float best_score=0;
363    float s = 1;
364    int sign;
365    float E;
366    float pred_gain;
367    int max_pos = N0-N/B;
368    if (max_pos > 32)
369       max_pos = 32;
370
371    for (i=0;i<max_pos*B;i+=B)
372    {
373       int j;
374       float xy=0, yy=0;
375       float score;
376       for (j=0;j<N;j++)
377       {
378          xy += 1.f*x[j]*Y[i+N-j-1];
379          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
380       }
381       score = xy*xy/(.001+yy);
382       if (score > best_score)
383       {
384          best_score = score;
385          best = i;
386          if (xy>0)
387             s = 1;
388          else
389             s = -1;
390       }
391    }
392    if (s<0)
393       sign = 1;
394    else
395       sign = 0;
396    /*printf ("%d %d ", sign, best);*/
397    ec_enc_uint(enc,sign,2);
398    ec_enc_uint(enc,best/B,max_pos);
399    /*printf ("%d %f\n", best, best_score);*/
400    
401    if (K>10)
402       pred_gain = pg[10];
403    else
404       pred_gain = pg[K];
405    E = 1e-10;
406    for (j=0;j<N;j++)
407    {
408       P[j] = s*Y[best+N-j-1];
409       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
410    }
411    E = pred_gain/sqrt(E);
412    for (j=0;j<N;j++)
413       P[j] *= E;
414    if (K>0)
415    {
416       for (j=0;j<N;j++)
417          x[j] -= P[j];
418    } else {
419       for (j=0;j<N;j++)
420          x[j] = P[j];
421    }
422    /*printf ("quant ");*/
423    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
424
425 }
426
427 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
428 {
429    int j;
430    int sign;
431    float s;
432    int best;
433    float E;
434    float pred_gain;
435    int max_pos = N0-N/B;
436    if (max_pos > 32)
437       max_pos = 32;
438    
439    sign = ec_dec_uint(dec, 2);
440    if (sign == 0)
441       s = 1;
442    else
443       s = -1;
444    
445    best = B*ec_dec_uint(dec, max_pos);
446    /*printf ("%d %d ", sign, best);*/
447
448    if (K>10)
449       pred_gain = pg[10];
450    else
451       pred_gain = pg[K];
452    E = 1e-10;
453    for (j=0;j<N;j++)
454    {
455       P[j] = s*Y[best+N-j-1];
456       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
457    }
458    E = pred_gain/sqrt(E);
459    for (j=0;j<N;j++)
460       P[j] *= E;
461    if (K==0)
462    {
463       for (j=0;j<N;j++)
464          x[j] = P[j];
465    }
466 }
467
468 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
469 {
470    int i, j;
471    float E;
472    
473    E = 1e-10;
474    if (N0 >= Nmax/2)
475    {
476       for (i=0;i<B;i++)
477       {
478          for (j=0;j<N/B;j++)
479          {
480             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
481             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
482          }
483       }
484    } else {
485       for (j=0;j<N;j++)
486       {
487          P[j] = Y[j];
488          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
489       }
490    }
491    E = 1.f/sqrt(E);
492    for (j=0;j<N;j++)
493       P[j] *= E;
494    for (j=0;j<N;j++)
495       x[j] = P[j];
496 }
497