More C89 fixes, making sure to include config.h from all source files.
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40
41 /* Enable this or define your own implementation if you want to speed up the
42    VQ search (used in inner loop only) */
43 #if 0
44 #include <xmmintrin.h>
45 static inline float approx_sqrt(float x)
46 {
47    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
48    return x;
49 }
50 static inline float approx_inv(float x)
51 {
52    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
53    return x;
54 }
55 #else
56 #define approx_sqrt(x) (sqrt(x))
57 #define approx_inv(x) (1.f/(x))
58 #endif
59
60 struct NBest {
61    float score;
62    float gain;
63    int sign;
64    int pos;
65    int orig;
66    float xy;
67    float yy;
68    float yp;
69 };
70
71 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
72    a combination of pulses such that its norm is still equal to 1. The only difference with 
73    the quantiser above is that the search is more complete. */
74 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
75 {
76    int L = 3;
77    float _y[L][N];
78    int _iy[L][N];
79    float _ny[L][N];
80    int _iny[L][N];
81    float *(ny[L]), *(y[L]);
82    int *(iny[L]), *(iy[L]);
83    int i, j, k, m;
84    int pulsesLeft;
85    float xy[L];
86    float yy[L];
87    float yp[L];
88    struct NBest _nbest[L];
89    struct NBest *(nbest[L]);
90    float Rpp=0, Rxp=0;
91    int maxL = 1;
92    
93    for (m=0;m<L;m++)
94       nbest[m] = &_nbest[m];
95    
96    for (m=0;m<L;m++)
97    {
98       ny[m] = _ny[m];
99       iny[m] = _iny[m];
100       y[m] = _y[m];
101       iy[m] = _iy[m];
102    }
103    
104    for (j=0;j<N;j++)
105    {
106       Rpp += p[j]*p[j];
107       Rxp += x[j]*p[j];
108    }
109    
110    /* We only need to initialise the zero because the first iteration only uses that */
111    for (i=0;i<N;i++)
112       y[0][i] = 0;
113    for (i=0;i<N;i++)
114       iy[0][i] = 0;
115    xy[0] = yy[0] = yp[0] = 0;
116
117    pulsesLeft = K;
118    while (pulsesLeft > 0)
119    {
120       int pulsesAtOnce=1;
121       int Lupdate = L;
122       int L2 = L;
123       
124       /* Decide on complexity strategy */
125       pulsesAtOnce = pulsesLeft/N;
126       if (pulsesAtOnce<1)
127          pulsesAtOnce = 1;
128       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
129          Lupdate = 1;
130       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
131       L2 = Lupdate;
132       if (L2>maxL)
133       {
134          L2 = maxL;
135          maxL *= N;
136       }
137
138       for (m=0;m<Lupdate;m++)
139          nbest[m]->score = -1e10f;
140
141       for (m=0;m<L2;m++)
142       {
143          for (j=0;j<N;j++)
144          {
145             int sign;
146             /*if (x[j]>0) sign=1; else sign=-1;*/
147             for (sign=-1;sign<=1;sign+=2)
148             {
149                /* All pulses at one location must have the same sign. */
150                if (iy[m][j]*sign < 0)
151                   continue;
152                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
153                float tmp_xy, tmp_yy, tmp_yp;
154                float score;
155                float g;
156                float s = sign*pulsesAtOnce;
157                
158                /* Updating the sums of the new pulse(s) */
159                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
160                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2.f*alpha*s*p[j]*yp[m] - 2.f*s*s*alpha*p[j]*p[j];
161                tmp_yp = yp[m] + s*p[j]               *(1.f-alpha*Rpp);
162                
163                /* Compute the gain such that ||p + g*y|| = 1 */
164                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
165                /* Knowing that gain, what the error: (x-g*y)^2 
166                   (result is negated and we discard x^2 because it's constant) */
167                score = 2.f*g*tmp_xy - g*g*tmp_yy;
168
169                if (score>nbest[Lupdate-1]->score)
170                {
171                   int k;
172                   int id = Lupdate-1;
173                   struct NBest *tmp_best;
174
175                   /* Save some pointers that would be deleted and use them for the current entry*/
176                   tmp_best = nbest[Lupdate-1];
177                   while (id > 0 && score > nbest[id-1]->score)
178                      id--;
179                
180                   for (k=Lupdate-1;k>id;k--)
181                      nbest[k] = nbest[k-1];
182
183                   nbest[id] = tmp_best;
184                   nbest[id]->score = score;
185                   nbest[id]->pos = j;
186                   nbest[id]->orig = m;
187                   nbest[id]->sign = sign;
188                   nbest[id]->gain = g;
189                   nbest[id]->xy = tmp_xy;
190                   nbest[id]->yy = tmp_yy;
191                   nbest[id]->yp = tmp_yp;
192                }
193             }
194          }
195
196       }
197       /* Only now that we've made the final choice, update ny/iny and others */
198       for (k=0;k<Lupdate;k++)
199       {
200          int n;
201          int is;
202          float s;
203          is = nbest[k]->sign*pulsesAtOnce;
204          s = is;
205          for (n=0;n<N;n++)
206             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
207          ny[k][nbest[k]->pos] += s;
208
209          for (n=0;n<N;n++)
210             iny[k][n] = iy[nbest[k]->orig][n];
211          iny[k][nbest[k]->pos] += is;
212
213          xy[k] = nbest[k]->xy;
214          yy[k] = nbest[k]->yy;
215          yp[k] = nbest[k]->yp;
216       }
217       /* Swap ny/iny with y/iy */
218       for (k=0;k<Lupdate;k++)
219       {
220          float *tmp_ny;
221          int *tmp_iny;
222
223          tmp_ny = ny[k];
224          ny[k] = y[k];
225          y[k] = tmp_ny;
226          tmp_iny = iny[k];
227          iny[k] = iy[k];
228          iy[k] = tmp_iny;
229       }
230       pulsesLeft -= pulsesAtOnce;
231    }
232    
233    if (0) {
234       float err=0;
235       for (i=0;i<N;i++)
236          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
237       /*if (N<=10)
238         printf ("%f %d %d\n", err, K, N);*/
239    }
240    for (i=0;i<N;i++)
241       x[i] = p[i]+nbest[0]->gain*y[0][i];
242    /* Sanity checks, don't bother */
243    if (0) {
244       float E=1e-15;
245       int ABS = 0;
246       for (i=0;i<N;i++)
247          ABS += abs(iy[0][i]);
248       /*if (K != ABS)
249          printf ("%d %d\n", K, ABS);*/
250       for (i=0;i<N;i++)
251          E += x[i]*x[i];
252       /*printf ("%f\n", E);*/
253       E = 1/sqrt(E);
254       for (i=0;i<N;i++)
255          x[i] *= E;
256    }
257    
258    encode_pulses(iy[0], N, K, enc);
259    
260    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
261       due to the recursive computation used in quantisation.
262       Not quite sure whether we need that or not */
263    if (1) {
264       float Ryp=0;
265       float Ryy=0;
266       float g=0;
267       
268       for (i=0;i<N;i++)
269          Ryp += iy[0][i]*p[i];
270       
271       for (i=0;i<N;i++)
272          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
273       
274       Ryp = 0;
275       for (i=0;i<N;i++)
276          Ryp += y[0][i]*p[i];
277       
278       for (i=0;i<N;i++)
279          Ryy += y[0][i]*y[0][i];
280       
281       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
282         
283       for (i=0;i<N;i++)
284          x[i] = p[i] + g*y[0][i];
285       
286    }
287
288 }
289
290 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
291 {
292    int i;
293    int iy[N];
294    float y[N];
295    float Rpp=0, Ryp=0, Ryy=0;
296    float g;
297
298    decode_pulses(iy, N, K, dec);
299
300    /*for (i=0;i<N;i++)
301       printf ("%d ", iy[i]);*/
302    for (i=0;i<N;i++)
303       Rpp += p[i]*p[i];
304
305    for (i=0;i<N;i++)
306       Ryp += iy[i]*p[i];
307
308    for (i=0;i<N;i++)
309       y[i] = iy[i] - alpha*Ryp*p[i];
310
311    /* Recompute after the projection (I think it's right) */
312    Ryp = 0;
313    for (i=0;i<N;i++)
314       Ryp += y[i]*p[i];
315
316    for (i=0;i<N;i++)
317       Ryy += y[i]*y[i];
318
319    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
320
321    for (i=0;i<N;i++)
322       x[i] = p[i] + g*y[i];
323 }
324
325
326 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
327
328 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
329 {
330    int i,j;
331    int best=0;
332    float best_score=0;
333    float s = 1;
334    int sign;
335    float E;
336    int max_pos = N0-N/B;
337    if (max_pos > 32)
338       max_pos = 32;
339
340    for (i=0;i<max_pos*B;i+=B)
341    {
342       int j;
343       float xy=0, yy=0;
344       float score;
345       for (j=0;j<N;j++)
346       {
347          xy += x[j]*Y[i+j];
348          yy += Y[i+j]*Y[i+j];
349       }
350       score = xy*xy/(.001+yy);
351       if (score > best_score)
352       {
353          best_score = score;
354          best = i;
355          if (xy>0)
356             s = 1;
357          else
358             s = -1;
359       }
360    }
361    if (s<0)
362       sign = 1;
363    else
364       sign = 0;
365    /*printf ("%d %d ", sign, best);*/
366    ec_enc_uint(enc,sign,2);
367    ec_enc_uint(enc,best/B,max_pos);
368    /*printf ("%d %f\n", best, best_score);*/
369    
370    float pred_gain;
371    if (K>10)
372       pred_gain = pg[10];
373    else
374       pred_gain = pg[K];
375    E = 1e-10;
376    for (j=0;j<N;j++)
377    {
378       P[j] = s*Y[best+j];
379       E += P[j]*P[j];
380    }
381    E = pred_gain/sqrt(E);
382    for (j=0;j<N;j++)
383       P[j] *= E;
384    if (K>0)
385    {
386       for (j=0;j<N;j++)
387          x[j] -= P[j];
388    } else {
389       for (j=0;j<N;j++)
390          x[j] = P[j];
391    }
392    /*printf ("quant ");*/
393    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
394
395 }
396
397 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
398 {
399    int j;
400    int sign;
401    float s;
402    int best;
403    float E;
404    int max_pos = N0-N/B;
405    if (max_pos > 32)
406       max_pos = 32;
407    
408    sign = ec_dec_uint(dec, 2);
409    if (sign == 0)
410       s = 1;
411    else
412       s = -1;
413    
414    best = B*ec_dec_uint(dec, max_pos);
415    /*printf ("%d %d ", sign, best);*/
416
417    float pred_gain;
418    if (K>10)
419       pred_gain = pg[10];
420    else
421       pred_gain = pg[K];
422    E = 1e-10;
423    for (j=0;j<N;j++)
424    {
425       P[j] = s*Y[best+j];
426       E += P[j]*P[j];
427    }
428    E = pred_gain/sqrt(E);
429    for (j=0;j<N;j++)
430       P[j] *= E;
431    if (K==0)
432    {
433       for (j=0;j<N;j++)
434          x[j] = P[j];
435    }
436 }
437
438 void intra_fold(float *x, int N, float *Y, float *P, int B, int N0, int Nmax)
439 {
440    int i, j;
441    float E;
442    
443    E = 1e-10;
444    if (N0 >= Nmax/2)
445    {
446       for (i=0;i<B;i++)
447       {
448          for (j=0;j<N/B;j++)
449          {
450             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
451             E += P[j*B+i]*P[j*B+i];
452          }
453       }
454    } else {
455       for (j=0;j<N;j++)
456       {
457          P[j] = Y[j];
458          E += P[j]*P[j];
459       }
460    }
461    E = 1.f/sqrt(E);
462    for (j=0;j<N;j++)
463       P[j] *= E;
464    for (j=0;j<N;j++)
465       x[j] = P[j];
466 }
467