fixed-point: The cross-products in alg_quant() are now all converted.
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72 #ifdef FIXED_POINT
73    int yshift = 14-EC_ILOG(K);
74 #endif
75    ALLOC(y, N, celt_norm_t);
76
77    /*for (i=0;i<N;i++)
78    printf ("%d ", iy[i]);*/
79    Rpp = 0;
80    for (i=0;i<N;i++)
81       Rpp = MAC16_16(Rpp,P[i],P[i]);
82
83    Ryp = 0;
84    for (i=0;i<N;i++)
85       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
86
87    /* Remove part of the pitch component to compute the real residual from
88       the encoded (int) one */
89    for (i=0;i<N;i++)
90       y[i] = SUB16(SHL16(iy[i],yshift),
91                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
92
93    /* Recompute after the projection (I think it's right) */
94    Ryp = 0;
95    for (i=0;i<N;i++)
96       Ryp = MAC16_16(Ryp,y[i],P[i]);
97
98    Ryy = 0;
99    for (i=0;i<N;i++)
100       Ryy = MAC16_16(Ryy, y[i],y[i]);
101
102    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
103    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
104
105    for (i=0;i<N;i++)
106       X[i] = P[i] + MULT16_32_Q14(y[i], g);
107 }
108
109 /** All the info necessary to keep track of a hypothesis during the search */
110 struct NBest {
111    float score;
112    float gain;
113    int sign;
114    int pos;
115    int orig;
116    float xy;
117    float yy;
118    float yp;
119 };
120
121 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
122 {
123    int L = 3;
124    VARDECL(float *p);
125    VARDECL(float *_y);
126    VARDECL(float *_ny);
127    VARDECL(int *_iy);
128    VARDECL(int *_iny);
129    VARDECL(float **y);
130    VARDECL(float **ny);
131    VARDECL(int **iy);
132    VARDECL(int **iny);
133    int i, j, k, m;
134    int pulsesLeft;
135    VARDECL(celt_word32_t *xy);
136    VARDECL(celt_word32_t *yy);
137    VARDECL(celt_word32_t *yp);
138    VARDECL(struct NBest *_nbest);
139    VARDECL(struct NBest **nbest);
140    celt_word32_t Rpp=0, Rxp=0;
141    int maxL = 1;
142    float _alpha = Q15_ONE_1*alpha;
143 #ifdef FIXED_POINT
144    int yshift = 14-EC_ILOG(K);
145 #endif
146
147    ALLOC(p, N, float);
148    ALLOC(_y, L*N, float);
149    ALLOC(_ny, L*N, float);
150    ALLOC(_iy, L*N, int);
151    ALLOC(_iny, L*N, int);
152    ALLOC(y, L*N, float*);
153    ALLOC(ny, L*N, float*);
154    ALLOC(iy, L*N, int*);
155    ALLOC(iny, L*N, int*);
156    
157    ALLOC(xy, L, celt_word32_t);
158    ALLOC(yy, L, celt_word32_t);
159    ALLOC(yp, L, celt_word32_t);
160    ALLOC(_nbest, L, struct NBest);
161    ALLOC(nbest, L, struct NBest *);
162
163    for (j=0;j<N;j++)
164    {
165       p[j] = P[j]*NORM_SCALING_1;
166    }
167    
168    for (m=0;m<L;m++)
169       nbest[m] = &_nbest[m];
170    
171    for (m=0;m<L;m++)
172    {
173       ny[m] = &_ny[m*N];
174       iny[m] = &_iny[m*N];
175       y[m] = &_y[m*N];
176       iy[m] = &_iy[m*N];
177    }
178    
179    for (j=0;j<N;j++)
180    {
181       Rpp = MAC16_16(Rpp, P[j],P[j]);
182       Rxp = MAC16_16(Rxp, X[j],P[j]);
183    }
184    Rpp = ROUND(Rpp, NORM_SHIFT);
185    Rxp = ROUND(Rxp, NORM_SHIFT);
186    if (Rpp>NORM_SCALING)
187       celt_fatal("Rpp > 1");
188
189    /* We only need to initialise the zero because the first iteration only uses that */
190    for (i=0;i<N;i++)
191       y[0][i] = 0;
192    for (i=0;i<N;i++)
193       iy[0][i] = 0;
194    xy[0] = yy[0] = yp[0] = 0;
195
196    pulsesLeft = K;
197    while (pulsesLeft > 0)
198    {
199       int pulsesAtOnce=1;
200       int Lupdate = L;
201       int L2 = L;
202       
203       /* Decide on complexity strategy */
204       pulsesAtOnce = pulsesLeft/N;
205       if (pulsesAtOnce<1)
206          pulsesAtOnce = 1;
207       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
208          Lupdate = 1;
209       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
210       L2 = Lupdate;
211       if (L2>maxL)
212       {
213          L2 = maxL;
214          maxL *= N;
215       }
216
217       for (m=0;m<Lupdate;m++)
218          nbest[m]->score = -1e10f;
219
220       for (m=0;m<L2;m++)
221       {
222          for (j=0;j<N;j++)
223          {
224             int sign;
225             /*if (x[j]>0) sign=1; else sign=-1;*/
226             for (sign=-1;sign<=1;sign+=2)
227             {
228                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
229                celt_word32_t tmp_xy, tmp_yy, tmp_yp;
230                celt_word16_t spj, aspj;
231                float score;
232                float g;
233                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
234                
235                /* All pulses at one location must have the same sign. */
236                if (iy[m][j]*sign < 0)
237                   continue;
238
239                spj = MULT16_16_P14(s, P[j]);
240                aspj = MULT16_16_P15(alpha, spj);
241                /* Updating the sums of the new pulse(s) */
242                tmp_xy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_P15(alpha,spj),Rxp);
243                tmp_yy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
244                tmp_yp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
245                
246                /* Compute the gain such that ||p + g*y|| = 1 */
247                g = (approx_sqrt(NORM_SCALING_1*NORM_SCALING_1*tmp_yp*tmp_yp + tmp_yy - NORM_SCALING_1*tmp_yy*Rpp) - tmp_yp*NORM_SCALING_1)*approx_inv(tmp_yy);
248                /* Knowing that gain, what the error: (x-g*y)^2 
249                   (result is negated and we discard x^2 because it's constant) */
250                score = 2.f*g*tmp_xy*NORM_SCALING_1 - g*g*tmp_yy;
251
252                if (score>nbest[Lupdate-1]->score)
253                {
254                   int k;
255                   int id = Lupdate-1;
256                   struct NBest *tmp_best;
257
258                   /* Save some pointers that would be deleted and use them for the current entry*/
259                   tmp_best = nbest[Lupdate-1];
260                   while (id > 0 && score > nbest[id-1]->score)
261                      id--;
262                
263                   for (k=Lupdate-1;k>id;k--)
264                      nbest[k] = nbest[k-1];
265
266                   nbest[id] = tmp_best;
267                   nbest[id]->score = score;
268                   nbest[id]->pos = j;
269                   nbest[id]->orig = m;
270                   nbest[id]->sign = sign;
271                   nbest[id]->gain = g;
272                   nbest[id]->xy = tmp_xy;
273                   nbest[id]->yy = tmp_yy;
274                   nbest[id]->yp = tmp_yp;
275                }
276             }
277          }
278
279       }
280       
281       if (!(nbest[0]->score > -1e10f))
282          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
283       /* Only now that we've made the final choice, update ny/iny and others */
284       for (k=0;k<Lupdate;k++)
285       {
286          int n;
287          int is;
288          float s;
289          is = nbest[k]->sign*pulsesAtOnce;
290          s = SHL16(is, yshift);
291          for (n=0;n<N;n++)
292             ny[k][n] = y[nbest[k]->orig][n] - _alpha*MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n]));
293          ny[k][nbest[k]->pos] += s;
294
295          for (n=0;n<N;n++)
296             iny[k][n] = iy[nbest[k]->orig][n];
297          iny[k][nbest[k]->pos] += is;
298
299          xy[k] = nbest[k]->xy;
300          yy[k] = nbest[k]->yy;
301          yp[k] = nbest[k]->yp;
302       }
303       /* Swap ny/iny with y/iy */
304       for (k=0;k<Lupdate;k++)
305       {
306          float *tmp_ny;
307          int *tmp_iny;
308
309          tmp_ny = ny[k];
310          ny[k] = y[k];
311          y[k] = tmp_ny;
312          tmp_iny = iny[k];
313          iny[k] = iy[k];
314          iy[k] = tmp_iny;
315       }
316       pulsesLeft -= pulsesAtOnce;
317    }
318    
319 #if 0
320    if (0) {
321       float err=0;
322       for (i=0;i<N;i++)
323          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
324       /*if (N<=10)
325         printf ("%f %d %d\n", err, K, N);*/
326    }
327    /* Sanity checks, don't bother */
328    if (0) {
329       for (i=0;i<N;i++)
330          x[i] = p[i]+nbest[0]->gain*y[0][i];
331       float E=1e-15;
332       int ABS = 0;
333       for (i=0;i<N;i++)
334          ABS += abs(iy[0][i]);
335       /*if (K != ABS)
336          printf ("%d %d\n", K, ABS);*/
337       for (i=0;i<N;i++)
338          E += x[i]*x[i];
339       /*printf ("%f\n", E);*/
340       E = 1/sqrt(E);
341       for (i=0;i<N;i++)
342          x[i] *= E;
343    }
344 #endif
345    
346    encode_pulses(iy[0], N, K, enc);
347    
348    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
349       due to the recursive computation used in quantisation.
350       Not quite sure whether we need that or not */
351    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
352 }
353
354 /** Decode pulse vector and combine the result with the pitch vector to produce
355     the final normalised signal in the current band. */
356 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
357 {
358    VARDECL(int *iy);
359    ALLOC(iy, N, int);
360    decode_pulses(iy, N, K, dec);
361    mix_pitch_and_residual(iy, X, N, K, P, alpha);
362 }
363
364
365 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
366
367 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
368 {
369    int i,j;
370    int best=0;
371    float best_score=0;
372    float s = 1;
373    int sign;
374    float E;
375    float pred_gain;
376    int max_pos = N0-N/B;
377    if (max_pos > 32)
378       max_pos = 32;
379
380    for (i=0;i<max_pos*B;i+=B)
381    {
382       int j;
383       float xy=0, yy=0;
384       float score;
385       for (j=0;j<N;j++)
386       {
387          xy += 1.f*x[j]*Y[i+N-j-1];
388          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
389       }
390       score = xy*xy/(.001+yy);
391       if (score > best_score)
392       {
393          best_score = score;
394          best = i;
395          if (xy>0)
396             s = 1;
397          else
398             s = -1;
399       }
400    }
401    if (s<0)
402       sign = 1;
403    else
404       sign = 0;
405    /*printf ("%d %d ", sign, best);*/
406    ec_enc_uint(enc,sign,2);
407    ec_enc_uint(enc,best/B,max_pos);
408    /*printf ("%d %f\n", best, best_score);*/
409    
410    if (K>10)
411       pred_gain = pg[10];
412    else
413       pred_gain = pg[K];
414    E = 1e-10;
415    for (j=0;j<N;j++)
416    {
417       P[j] = s*Y[best+N-j-1];
418       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
419    }
420    E = pred_gain/sqrt(E);
421    for (j=0;j<N;j++)
422       P[j] *= E;
423    if (K>0)
424    {
425       for (j=0;j<N;j++)
426          x[j] -= P[j];
427    } else {
428       for (j=0;j<N;j++)
429          x[j] = P[j];
430    }
431    /*printf ("quant ");*/
432    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
433
434 }
435
436 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
437 {
438    int j;
439    int sign;
440    float s;
441    int best;
442    float E;
443    float pred_gain;
444    int max_pos = N0-N/B;
445    if (max_pos > 32)
446       max_pos = 32;
447    
448    sign = ec_dec_uint(dec, 2);
449    if (sign == 0)
450       s = 1;
451    else
452       s = -1;
453    
454    best = B*ec_dec_uint(dec, max_pos);
455    /*printf ("%d %d ", sign, best);*/
456
457    if (K>10)
458       pred_gain = pg[10];
459    else
460       pred_gain = pg[K];
461    E = 1e-10;
462    for (j=0;j<N;j++)
463    {
464       P[j] = s*Y[best+N-j-1];
465       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
466    }
467    E = pred_gain/sqrt(E);
468    for (j=0;j<N;j++)
469       P[j] *= E;
470    if (K==0)
471    {
472       for (j=0;j<N;j++)
473          x[j] = P[j];
474    }
475 }
476
477 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
478 {
479    int i, j;
480    float E;
481    
482    E = 1e-10;
483    if (N0 >= Nmax/2)
484    {
485       for (i=0;i<B;i++)
486       {
487          for (j=0;j<N/B;j++)
488          {
489             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
490             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
491          }
492       }
493    } else {
494       for (j=0;j<N;j++)
495       {
496          P[j] = Y[j];
497          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
498       }
499    }
500    E = 1.f/sqrt(E);
501    for (j=0;j<N;j++)
502       P[j] *= E;
503    for (j=0;j<N;j++)
504       x[j] = P[j];
505 }
506