Everything should now compile with a C89 compiler.
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41
42 /* Enable this or define your own implementation if you want to speed up the
43    VQ search (used in inner loop only) */
44 #if 0
45 #include <xmmintrin.h>
46 static inline float approx_sqrt(float x)
47 {
48    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
49    return x;
50 }
51 static inline float approx_inv(float x)
52 {
53    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
54    return x;
55 }
56 #else
57 #define approx_sqrt(x) (sqrt(x))
58 #define approx_inv(x) (1.f/(x))
59 #endif
60
61 struct NBest {
62    float score;
63    float gain;
64    int sign;
65    int pos;
66    int orig;
67    float xy;
68    float yy;
69    float yp;
70 };
71
72 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
73    a combination of pulses such that its norm is still equal to 1. The only difference with 
74    the quantiser above is that the search is more complete. */
75 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
76 {
77    int L = 3;
78    VARDECL(float *_y);
79    VARDECL(float *_ny);
80    VARDECL(int *_iy);
81    VARDECL(int *_iny);
82    VARDECL(float **y);
83    VARDECL(float **ny);
84    VARDECL(int **iy);
85    VARDECL(int **iny);
86    int i, j, k, m;
87    int pulsesLeft;
88    VARDECL(float *xy);
89    VARDECL(float *yy);
90    VARDECL(float *yp);
91    VARDECL(struct NBest *_nbest);
92    VARDECL(struct NBest **nbest);
93    float Rpp=0, Rxp=0;
94    int maxL = 1;
95    
96    ALLOC(_y, L*N, float);
97    ALLOC(_ny, L*N, float);
98    ALLOC(_iy, L*N, int);
99    ALLOC(_iny, L*N, int);
100    ALLOC(y, L*N, float*);
101    ALLOC(ny, L*N, float*);
102    ALLOC(iy, L*N, int*);
103    ALLOC(iny, L*N, int*);
104    
105    ALLOC(xy, L, float);
106    ALLOC(yy, L, float);
107    ALLOC(yp, L, float);
108    ALLOC(_nbest, L, struct NBest);
109    ALLOC(nbest, L, struct NBest *);
110
111    for (m=0;m<L;m++)
112       nbest[m] = &_nbest[m];
113    
114    for (m=0;m<L;m++)
115    {
116       ny[m] = &_ny[m*N];
117       iny[m] = &_iny[m*N];
118       y[m] = &_y[m*N];
119       iy[m] = &_iy[m*N];
120    }
121    
122    for (j=0;j<N;j++)
123    {
124       Rpp += p[j]*p[j];
125       Rxp += x[j]*p[j];
126    }
127    
128    /* We only need to initialise the zero because the first iteration only uses that */
129    for (i=0;i<N;i++)
130       y[0][i] = 0;
131    for (i=0;i<N;i++)
132       iy[0][i] = 0;
133    xy[0] = yy[0] = yp[0] = 0;
134
135    pulsesLeft = K;
136    while (pulsesLeft > 0)
137    {
138       int pulsesAtOnce=1;
139       int Lupdate = L;
140       int L2 = L;
141       
142       /* Decide on complexity strategy */
143       pulsesAtOnce = pulsesLeft/N;
144       if (pulsesAtOnce<1)
145          pulsesAtOnce = 1;
146       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
147          Lupdate = 1;
148       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
149       L2 = Lupdate;
150       if (L2>maxL)
151       {
152          L2 = maxL;
153          maxL *= N;
154       }
155
156       for (m=0;m<Lupdate;m++)
157          nbest[m]->score = -1e10f;
158
159       for (m=0;m<L2;m++)
160       {
161          for (j=0;j<N;j++)
162          {
163             int sign;
164             /*if (x[j]>0) sign=1; else sign=-1;*/
165             for (sign=-1;sign<=1;sign+=2)
166             {
167                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
168                float tmp_xy, tmp_yy, tmp_yp;
169                float score;
170                float g;
171                float s = sign*pulsesAtOnce;
172                
173                /* All pulses at one location must have the same sign. */
174                if (iy[m][j]*sign < 0)
175                   continue;
176
177                /* Updating the sums of the new pulse(s) */
178                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
179                tmp_yy = yy[m] + 2.f*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2.f*alpha*s*p[j]*yp[m] - 2.f*s*s*alpha*p[j]*p[j];
180                tmp_yp = yp[m] + s*p[j]               *(1.f-alpha*Rpp);
181                
182                /* Compute the gain such that ||p + g*y|| = 1 */
183                g = (approx_sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)*approx_inv(tmp_yy);
184                /* Knowing that gain, what the error: (x-g*y)^2 
185                   (result is negated and we discard x^2 because it's constant) */
186                score = 2.f*g*tmp_xy - g*g*tmp_yy;
187
188                if (score>nbest[Lupdate-1]->score)
189                {
190                   int k;
191                   int id = Lupdate-1;
192                   struct NBest *tmp_best;
193
194                   /* Save some pointers that would be deleted and use them for the current entry*/
195                   tmp_best = nbest[Lupdate-1];
196                   while (id > 0 && score > nbest[id-1]->score)
197                      id--;
198                
199                   for (k=Lupdate-1;k>id;k--)
200                      nbest[k] = nbest[k-1];
201
202                   nbest[id] = tmp_best;
203                   nbest[id]->score = score;
204                   nbest[id]->pos = j;
205                   nbest[id]->orig = m;
206                   nbest[id]->sign = sign;
207                   nbest[id]->gain = g;
208                   nbest[id]->xy = tmp_xy;
209                   nbest[id]->yy = tmp_yy;
210                   nbest[id]->yp = tmp_yp;
211                }
212             }
213          }
214
215       }
216       /* Only now that we've made the final choice, update ny/iny and others */
217       for (k=0;k<Lupdate;k++)
218       {
219          int n;
220          int is;
221          float s;
222          is = nbest[k]->sign*pulsesAtOnce;
223          s = is;
224          for (n=0;n<N;n++)
225             ny[k][n] = y[nbest[k]->orig][n] - alpha*s*p[nbest[k]->pos]*p[n];
226          ny[k][nbest[k]->pos] += s;
227
228          for (n=0;n<N;n++)
229             iny[k][n] = iy[nbest[k]->orig][n];
230          iny[k][nbest[k]->pos] += is;
231
232          xy[k] = nbest[k]->xy;
233          yy[k] = nbest[k]->yy;
234          yp[k] = nbest[k]->yp;
235       }
236       /* Swap ny/iny with y/iy */
237       for (k=0;k<Lupdate;k++)
238       {
239          float *tmp_ny;
240          int *tmp_iny;
241
242          tmp_ny = ny[k];
243          ny[k] = y[k];
244          y[k] = tmp_ny;
245          tmp_iny = iny[k];
246          iny[k] = iy[k];
247          iy[k] = tmp_iny;
248       }
249       pulsesLeft -= pulsesAtOnce;
250    }
251    
252    if (0) {
253       float err=0;
254       for (i=0;i<N;i++)
255          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
256       /*if (N<=10)
257         printf ("%f %d %d\n", err, K, N);*/
258    }
259    for (i=0;i<N;i++)
260       x[i] = p[i]+nbest[0]->gain*y[0][i];
261    /* Sanity checks, don't bother */
262    if (0) {
263       float E=1e-15;
264       int ABS = 0;
265       for (i=0;i<N;i++)
266          ABS += abs(iy[0][i]);
267       /*if (K != ABS)
268          printf ("%d %d\n", K, ABS);*/
269       for (i=0;i<N;i++)
270          E += x[i]*x[i];
271       /*printf ("%f\n", E);*/
272       E = 1/sqrt(E);
273       for (i=0;i<N;i++)
274          x[i] *= E;
275    }
276    
277    encode_pulses(iy[0], N, K, enc);
278    
279    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
280       due to the recursive computation used in quantisation.
281       Not quite sure whether we need that or not */
282    if (1) {
283       float Ryp=0;
284       float Ryy=0;
285       float g=0;
286       
287       for (i=0;i<N;i++)
288          Ryp += iy[0][i]*p[i];
289       
290       for (i=0;i<N;i++)
291          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
292       
293       Ryp = 0;
294       for (i=0;i<N;i++)
295          Ryp += y[0][i]*p[i];
296       
297       for (i=0;i<N;i++)
298          Ryy += y[0][i]*y[0][i];
299       
300       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
301         
302       for (i=0;i<N;i++)
303          x[i] = p[i] + g*y[0][i];
304       
305    }
306
307 }
308
309 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
310 {
311    int i;
312    float Rpp=0, Ryp=0, Ryy=0;
313    float g;
314    VARDECL(int *iy);
315    VARDECL(float *y);
316    
317    ALLOC(iy, N, int);
318    ALLOC(y, N, float);
319
320    decode_pulses(iy, N, K, dec);
321
322    /*for (i=0;i<N;i++)
323       printf ("%d ", iy[i]);*/
324    for (i=0;i<N;i++)
325       Rpp += p[i]*p[i];
326
327    for (i=0;i<N;i++)
328       Ryp += iy[i]*p[i];
329
330    for (i=0;i<N;i++)
331       y[i] = iy[i] - alpha*Ryp*p[i];
332
333    /* Recompute after the projection (I think it's right) */
334    Ryp = 0;
335    for (i=0;i<N;i++)
336       Ryp += y[i]*p[i];
337
338    for (i=0;i<N;i++)
339       Ryy += y[i]*y[i];
340
341    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
342
343    for (i=0;i<N;i++)
344       x[i] = p[i] + g*y[i];
345 }
346
347
348 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
349
350 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
351 {
352    int i,j;
353    int best=0;
354    float best_score=0;
355    float s = 1;
356    int sign;
357    float E;
358    float pred_gain;
359    int max_pos = N0-N/B;
360    if (max_pos > 32)
361       max_pos = 32;
362
363    for (i=0;i<max_pos*B;i+=B)
364    {
365       int j;
366       float xy=0, yy=0;
367       float score;
368       for (j=0;j<N;j++)
369       {
370          xy += x[j]*Y[i+j];
371          yy += Y[i+j]*Y[i+j];
372       }
373       score = xy*xy/(.001+yy);
374       if (score > best_score)
375       {
376          best_score = score;
377          best = i;
378          if (xy>0)
379             s = 1;
380          else
381             s = -1;
382       }
383    }
384    if (s<0)
385       sign = 1;
386    else
387       sign = 0;
388    /*printf ("%d %d ", sign, best);*/
389    ec_enc_uint(enc,sign,2);
390    ec_enc_uint(enc,best/B,max_pos);
391    /*printf ("%d %f\n", best, best_score);*/
392    
393    if (K>10)
394       pred_gain = pg[10];
395    else
396       pred_gain = pg[K];
397    E = 1e-10;
398    for (j=0;j<N;j++)
399    {
400       P[j] = s*Y[best+j];
401       E += P[j]*P[j];
402    }
403    E = pred_gain/sqrt(E);
404    for (j=0;j<N;j++)
405       P[j] *= E;
406    if (K>0)
407    {
408       for (j=0;j<N;j++)
409          x[j] -= P[j];
410    } else {
411       for (j=0;j<N;j++)
412          x[j] = P[j];
413    }
414    /*printf ("quant ");*/
415    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
416
417 }
418
419 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
420 {
421    int j;
422    int sign;
423    float s;
424    int best;
425    float E;
426    float pred_gain;
427    int max_pos = N0-N/B;
428    if (max_pos > 32)
429       max_pos = 32;
430    
431    sign = ec_dec_uint(dec, 2);
432    if (sign == 0)
433       s = 1;
434    else
435       s = -1;
436    
437    best = B*ec_dec_uint(dec, max_pos);
438    /*printf ("%d %d ", sign, best);*/
439
440    if (K>10)
441       pred_gain = pg[10];
442    else
443       pred_gain = pg[K];
444    E = 1e-10;
445    for (j=0;j<N;j++)
446    {
447       P[j] = s*Y[best+j];
448       E += P[j]*P[j];
449    }
450    E = pred_gain/sqrt(E);
451    for (j=0;j<N;j++)
452       P[j] *= E;
453    if (K==0)
454    {
455       for (j=0;j<N;j++)
456          x[j] = P[j];
457    }
458 }
459
460 void intra_fold(float *x, int N, float *Y, float *P, int B, int N0, int Nmax)
461 {
462    int i, j;
463    float E;
464    
465    E = 1e-10;
466    if (N0 >= Nmax/2)
467    {
468       for (i=0;i<B;i++)
469       {
470          for (j=0;j<N/B;j++)
471          {
472             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
473             E += P[j*B+i]*P[j*B+i];
474          }
475       }
476    } else {
477       for (j=0;j<N;j++)
478       {
479          P[j] = Y[j];
480          E += P[j]*P[j];
481       }
482    }
483    E = 1.f/sqrt(E);
484    for (j=0;j<N;j++)
485       P[j] *= E;
486    for (j=0;j<N;j++)
487       x[j] = P[j];
488 }
489