removed // comments and added stack_alloc.h (not used everywhere yet)
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37 /* Enable this or define your own implementation if you want to speed up the
38    VQ search (used in inner loop only) */
39 #if 0
40 #include <xmmintrin.h>
41 static inline float approx_sqrt(float x)
42 {
43    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
44    return x;
45 }
46 static inline float approx_inv(float x)
47 {
48    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
49    return x;
50 }
51 #else
52 #define approx_sqrt(x) (sqrt(x))
53 #define approx_inv(x) (1.f/(x))
54 #endif
55
56 struct NBest {
57    float score;
58    float gain;
59    int sign;
60    int pos;
61    int orig;
62    float xy;
63    float yy;
64    float yp;
65 };
66
67 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
68    a combination of pulses such that its norm is still equal to 1. The only difference with 
69    the quantiser above is that the search is more complete. */
70 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
71 {
72    int L = 3;
73    float _y[L][N];
74    int _iy[L][N];
75    float _ny[L][N];
76    int _iny[L][N];
77    float *(ny[L]), *(y[L]);
78    int *(iny[L]), *(iy[L]);
79    int i, j, k, m;
80    int pulsesLeft;
81    float xy[L];
82    float yy[L];
83    float yp[L];
84    struct NBest _nbest[L];
85    struct NBest *(nbest[L]);
86    float Rpp=0, Rxp=0;
87    int maxL = 1;
88    
89    for (m=0;m<L;m++)
90       nbest[m] = &_nbest[m];
91    
92    for (m=0;m<L;m++)
93    {
94       ny[m] = _ny[m];
95       iny[m] = _iny[m];
96       y[m] = _y[m];
97       iy[m] = _iy[m];
98    }
99    
100    for (j=0;j<N;j++)
101    {
102       Rpp += p[j]*p[j];
103       Rxp += x[j]*p[j];
104    }
105    
106    /* We only need to initialise the zero because the first iteration only uses that */
107    for (i=0;i<N;i++)
108       y[0][i] = 0;
109    for (i=0;i<N;i++)
110       iy[0][i] = 0;
111    xy[0] = yy[0] = yp[0] = 0;
112
113    pulsesLeft = K;
114    while (pulsesLeft > 0)
115    {
116       int pulsesAtOnce=1;
117       int Lupdate = L;
118       int L2 = L;
119       
120       /* Decide on complexity strategy */
121       pulsesAtOnce = pulsesLeft/N;
122       if (pulsesAtOnce<1)
123          pulsesAtOnce = 1;
124       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
125          Lupdate = 1;
126       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
127       L2 = Lupdate;
128       if (L2>maxL)
129       {
130          L2 = maxL;
131          maxL *= N;
132       }
133
134       for (m=0;m<Lupdate;m++)
135          nbest[m]->score = -1e10f;
136
137       for (m=0;m<L2;m++)
138       {
139          for (j=0;j<N;j++)
140          {
141             int sign;
142             /*if (x[j]>0) sign=1; else sign=-1;*/
143             for (sign=-1;sign<=1;sign+=2)
144             {
145                /* All pulses at one location must have the same sign. */
146                if (iy[m][j]*sign < 0)
147                   continue;
148                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
149                float tmp_xy, tmp_yy, tmp_yp;
150                float score;
151                float g;
152                float s = sign*pulsesAtOnce;
153                
154                /* Updating the sums of the new pulse(s) */
155                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
156                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2.f*alpha*s*p[j]*yp[m] - 2.f*s*s*alpha*p[j]*p[j];
157                tmp_yp = yp[m] + s*p[j]               *(1.f-alpha*Rpp);
158                
159                /* Compute the gain such that ||p + g*y|| = 1 */
160                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
161                /* Knowing that gain, what the error: (x-g*y)^2 
162                   (result is negated and we discard x^2 because it's constant) */
163                score = 2.f*g*tmp_xy - g*g*tmp_yy;
164
165                if (score>nbest[Lupdate-1]->score)
166                {
167                   int k;
168                   int id = Lupdate-1;
169                   struct NBest *tmp_best;
170
171                   /* Save some pointers that would be deleted and use them for the current entry*/
172                   tmp_best = nbest[Lupdate-1];
173                   while (id > 0 && score > nbest[id-1]->score)
174                      id--;
175                
176                   for (k=Lupdate-1;k>id;k--)
177                      nbest[k] = nbest[k-1];
178
179                   nbest[id] = tmp_best;
180                   nbest[id]->score = score;
181                   nbest[id]->pos = j;
182                   nbest[id]->orig = m;
183                   nbest[id]->sign = sign;
184                   nbest[id]->gain = g;
185                   nbest[id]->xy = tmp_xy;
186                   nbest[id]->yy = tmp_yy;
187                   nbest[id]->yp = tmp_yp;
188                }
189             }
190          }
191
192       }
193       /* Only now that we've made the final choice, update ny/iny and others */
194       for (k=0;k<Lupdate;k++)
195       {
196          int n;
197          int is;
198          float s;
199          is = nbest[k]->sign*pulsesAtOnce;
200          s = is;
201          for (n=0;n<N;n++)
202             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
203          ny[k][nbest[k]->pos] += s;
204
205          for (n=0;n<N;n++)
206             iny[k][n] = iy[nbest[k]->orig][n];
207          iny[k][nbest[k]->pos] += is;
208
209          xy[k] = nbest[k]->xy;
210          yy[k] = nbest[k]->yy;
211          yp[k] = nbest[k]->yp;
212       }
213       /* Swap ny/iny with y/iy */
214       for (k=0;k<Lupdate;k++)
215       {
216          float *tmp_ny;
217          int *tmp_iny;
218
219          tmp_ny = ny[k];
220          ny[k] = y[k];
221          y[k] = tmp_ny;
222          tmp_iny = iny[k];
223          iny[k] = iy[k];
224          iy[k] = tmp_iny;
225       }
226       pulsesLeft -= pulsesAtOnce;
227    }
228    
229    if (0) {
230       float err=0;
231       for (i=0;i<N;i++)
232          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
233       /*if (N<=10)
234         printf ("%f %d %d\n", err, K, N);*/
235    }
236    for (i=0;i<N;i++)
237       x[i] = p[i]+nbest[0]->gain*y[0][i];
238    /* Sanity checks, don't bother */
239    if (0) {
240       float E=1e-15;
241       int ABS = 0;
242       for (i=0;i<N;i++)
243          ABS += abs(iy[0][i]);
244       /*if (K != ABS)
245          printf ("%d %d\n", K, ABS);*/
246       for (i=0;i<N;i++)
247          E += x[i]*x[i];
248       /*printf ("%f\n", E);*/
249       E = 1/sqrt(E);
250       for (i=0;i<N;i++)
251          x[i] *= E;
252    }
253    
254    encode_pulses(iy[0], N, K, enc);
255    
256    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
257       due to the recursive computation used in quantisation.
258       Not quite sure whether we need that or not */
259    if (1) {
260       float Ryp=0;
261       float Ryy=0;
262       float g=0;
263       
264       for (i=0;i<N;i++)
265          Ryp += iy[0][i]*p[i];
266       
267       for (i=0;i<N;i++)
268          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
269       
270       Ryp = 0;
271       for (i=0;i<N;i++)
272          Ryp += y[0][i]*p[i];
273       
274       for (i=0;i<N;i++)
275          Ryy += y[0][i]*y[0][i];
276       
277       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
278         
279       for (i=0;i<N;i++)
280          x[i] = p[i] + g*y[0][i];
281       
282    }
283
284 }
285
286 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
287 {
288    int i;
289    int iy[N];
290    float y[N];
291    float Rpp=0, Ryp=0, Ryy=0;
292    float g;
293
294    decode_pulses(iy, N, K, dec);
295
296    /*for (i=0;i<N;i++)
297       printf ("%d ", iy[i]);*/
298    for (i=0;i<N;i++)
299       Rpp += p[i]*p[i];
300
301    for (i=0;i<N;i++)
302       Ryp += iy[i]*p[i];
303
304    for (i=0;i<N;i++)
305       y[i] = iy[i] - alpha*Ryp*p[i];
306
307    /* Recompute after the projection (I think it's right) */
308    Ryp = 0;
309    for (i=0;i<N;i++)
310       Ryp += y[i]*p[i];
311
312    for (i=0;i<N;i++)
313       Ryy += y[i]*y[i];
314
315    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
316
317    for (i=0;i<N;i++)
318       x[i] = p[i] + g*y[i];
319 }
320
321
322 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
323
324 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
325 {
326    int i,j;
327    int best=0;
328    float best_score=0;
329    float s = 1;
330    int sign;
331    float E;
332    int max_pos = N0-N/B;
333    if (max_pos > 32)
334       max_pos = 32;
335
336    for (i=0;i<max_pos*B;i+=B)
337    {
338       int j;
339       float xy=0, yy=0;
340       float score;
341       for (j=0;j<N;j++)
342       {
343          xy += x[j]*Y[i+j];
344          yy += Y[i+j]*Y[i+j];
345       }
346       score = xy*xy/(.001+yy);
347       if (score > best_score)
348       {
349          best_score = score;
350          best = i;
351          if (xy>0)
352             s = 1;
353          else
354             s = -1;
355       }
356    }
357    if (s<0)
358       sign = 1;
359    else
360       sign = 0;
361    /*printf ("%d %d ", sign, best);*/
362    ec_enc_uint(enc,sign,2);
363    ec_enc_uint(enc,best/B,max_pos);
364    /*printf ("%d %f\n", best, best_score);*/
365    
366    float pred_gain;
367    if (K>10)
368       pred_gain = pg[10];
369    else
370       pred_gain = pg[K];
371    E = 1e-10;
372    for (j=0;j<N;j++)
373    {
374       P[j] = s*Y[best+j];
375       E += P[j]*P[j];
376    }
377    E = pred_gain/sqrt(E);
378    for (j=0;j<N;j++)
379       P[j] *= E;
380    if (K>0)
381    {
382       for (j=0;j<N;j++)
383          x[j] -= P[j];
384    } else {
385       for (j=0;j<N;j++)
386          x[j] = P[j];
387    }
388    /*printf ("quant ");*/
389    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
390
391 }
392
393 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
394 {
395    int j;
396    int sign;
397    float s;
398    int best;
399    float E;
400    int max_pos = N0-N/B;
401    if (max_pos > 32)
402       max_pos = 32;
403    
404    sign = ec_dec_uint(dec, 2);
405    if (sign == 0)
406       s = 1;
407    else
408       s = -1;
409    
410    best = B*ec_dec_uint(dec, max_pos);
411    /*printf ("%d %d ", sign, best);*/
412
413    float pred_gain;
414    if (K>10)
415       pred_gain = pg[10];
416    else
417       pred_gain = pg[K];
418    E = 1e-10;
419    for (j=0;j<N;j++)
420    {
421       P[j] = s*Y[best+j];
422       E += P[j]*P[j];
423    }
424    E = pred_gain/sqrt(E);
425    for (j=0;j<N;j++)
426       P[j] *= E;
427    if (K==0)
428    {
429       for (j=0;j<N;j++)
430          x[j] = P[j];
431    }
432 }
433
434 void intra_fold(float *x, int N, float *Y, float *P, int B, int N0, int Nmax)
435 {
436    int i, j;
437    float E;
438    
439    E = 1e-10;
440    if (N0 >= Nmax/2)
441    {
442       for (i=0;i<B;i++)
443       {
444          for (j=0;j<N/B;j++)
445          {
446             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
447             E += P[j*B+i]*P[j*B+i];
448          }
449       }
450    } else {
451       for (j=0;j<N;j++)
452       {
453          P[j] = Y[j];
454          E += P[j]*P[j];
455       }
456    }
457    E = 1.f/sqrt(E);
458    for (j=0;j<N;j++)
459       P[j] *= E;
460    for (j=0;j<N;j++)
461       x[j] = P[j];
462 }
463