fixed-point: (hopefully) last check-point for alg_quant() conversion
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /* Enable this or define your own implementation if you want to speed up the
45    VQ search (used in inner loop only) */
46 #if 0
47 #include <xmmintrin.h>
48 static inline float approx_sqrt(float x)
49 {
50    _mm_store_ss(&x, _mm_sqrt_ss(_mm_set_ss(x)));
51    return x;
52 }
53 static inline float approx_inv(float x)
54 {
55    _mm_store_ss(&x, _mm_rcp_ss(_mm_set_ss(x)));
56    return x;
57 }
58 #else
59 #define approx_sqrt(x) (sqrt(x))
60 #define approx_inv(x) (1.f/(x))
61 #endif
62
63 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
64    applies the compression in the pitch direction, computes the gain that will
65    give ||p+g*y||=1 and mixes the residual with the pitch. */
66 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha)
67 {
68    int i;
69    celt_word32_t Ryp, Ryy, Rpp;
70    celt_word32_t g;
71    VARDECL(celt_norm_t *y);
72 #ifdef FIXED_POINT
73    int yshift = 14-EC_ILOG(K);
74 #endif
75    ALLOC(y, N, celt_norm_t);
76
77    /*for (i=0;i<N;i++)
78    printf ("%d ", iy[i]);*/
79    Rpp = 0;
80    for (i=0;i<N;i++)
81       Rpp = MAC16_16(Rpp,P[i],P[i]);
82
83    Ryp = 0;
84    for (i=0;i<N;i++)
85       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
86
87    /* Remove part of the pitch component to compute the real residual from
88       the encoded (int) one */
89    for (i=0;i<N;i++)
90       y[i] = SUB16(SHL16(iy[i],yshift),
91                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
92
93    /* Recompute after the projection (I think it's right) */
94    Ryp = 0;
95    for (i=0;i<N;i++)
96       Ryp = MAC16_16(Ryp,y[i],P[i]);
97
98    Ryy = 0;
99    for (i=0;i<N;i++)
100       Ryy = MAC16_16(Ryy, y[i],y[i]);
101
102    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
103    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
104
105    for (i=0;i<N;i++)
106       X[i] = P[i] + MULT16_32_Q14(y[i], g);
107 }
108
109 /** All the info necessary to keep track of a hypothesis during the search */
110 struct NBest {
111    float score;
112    float gain;
113    int sign;
114    int pos;
115    int orig;
116    float xy;
117    float yy;
118    float yp;
119 };
120
121 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
122 {
123    int L = 3;
124    VARDECL(celt_norm_t *_y);
125    VARDECL(celt_norm_t *_ny);
126    VARDECL(int *_iy);
127    VARDECL(int *_iny);
128    VARDECL(celt_norm_t **y);
129    VARDECL(celt_norm_t **ny);
130    VARDECL(int **iy);
131    VARDECL(int **iny);
132    int i, j, k, m;
133    int pulsesLeft;
134    VARDECL(celt_word32_t *xy);
135    VARDECL(celt_word32_t *yy);
136    VARDECL(celt_word32_t *yp);
137    VARDECL(struct NBest *_nbest);
138    VARDECL(struct NBest **nbest);
139    celt_word32_t Rpp=0, Rxp=0;
140    int maxL = 1;
141 #ifdef FIXED_POINT
142    int yshift = 14-EC_ILOG(K);
143 #endif
144
145    ALLOC(_y, L*N, celt_norm_t);
146    ALLOC(_ny, L*N, celt_norm_t);
147    ALLOC(_iy, L*N, int);
148    ALLOC(_iny, L*N, int);
149    ALLOC(y, L*N, celt_norm_t*);
150    ALLOC(ny, L*N, celt_norm_t*);
151    ALLOC(iy, L*N, int*);
152    ALLOC(iny, L*N, int*);
153    
154    ALLOC(xy, L, celt_word32_t);
155    ALLOC(yy, L, celt_word32_t);
156    ALLOC(yp, L, celt_word32_t);
157    ALLOC(_nbest, L, struct NBest);
158    ALLOC(nbest, L, struct NBest *);
159    
160    for (m=0;m<L;m++)
161       nbest[m] = &_nbest[m];
162    
163    for (m=0;m<L;m++)
164    {
165       ny[m] = &_ny[m*N];
166       iny[m] = &_iny[m*N];
167       y[m] = &_y[m*N];
168       iy[m] = &_iy[m*N];
169    }
170    
171    for (j=0;j<N;j++)
172    {
173       Rpp = MAC16_16(Rpp, P[j],P[j]);
174       Rxp = MAC16_16(Rxp, X[j],P[j]);
175    }
176    Rpp = ROUND(Rpp, NORM_SHIFT);
177    Rxp = ROUND(Rxp, NORM_SHIFT);
178    if (Rpp>NORM_SCALING)
179       celt_fatal("Rpp > 1");
180
181    /* We only need to initialise the zero because the first iteration only uses that */
182    for (i=0;i<N;i++)
183       y[0][i] = 0;
184    for (i=0;i<N;i++)
185       iy[0][i] = 0;
186    xy[0] = yy[0] = yp[0] = 0;
187
188    pulsesLeft = K;
189    while (pulsesLeft > 0)
190    {
191       int pulsesAtOnce=1;
192       int Lupdate = L;
193       int L2 = L;
194       
195       /* Decide on complexity strategy */
196       pulsesAtOnce = pulsesLeft/N;
197       if (pulsesAtOnce<1)
198          pulsesAtOnce = 1;
199       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
200          Lupdate = 1;
201       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
202       L2 = Lupdate;
203       if (L2>maxL)
204       {
205          L2 = maxL;
206          maxL *= N;
207       }
208
209       for (m=0;m<Lupdate;m++)
210          nbest[m]->score = -1e10f;
211
212       for (m=0;m<L2;m++)
213       {
214          for (j=0;j<N;j++)
215          {
216             int sign;
217             /*if (x[j]>0) sign=1; else sign=-1;*/
218             for (sign=-1;sign<=1;sign+=2)
219             {
220                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
221                celt_word32_t tmp_xy, tmp_yy, tmp_yp;
222                celt_word16_t spj, aspj;
223                float score;
224                float g;
225                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
226                
227                /* All pulses at one location must have the same sign. */
228                if (iy[m][j]*sign < 0)
229                   continue;
230
231                spj = MULT16_16_P14(s, P[j]);
232                aspj = MULT16_16_P15(alpha, spj);
233                /* Updating the sums of the new pulse(s) */
234                tmp_xy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_P15(alpha,spj),Rxp);
235                tmp_yy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
236                tmp_yp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
237                
238                /* Compute the gain such that ||p + g*y|| = 1 */
239                g = (approx_sqrt(NORM_SCALING_1*NORM_SCALING_1*tmp_yp*tmp_yp + tmp_yy - NORM_SCALING_1*tmp_yy*Rpp) - tmp_yp*NORM_SCALING_1)*approx_inv(tmp_yy);
240                /* Knowing that gain, what the error: (x-g*y)^2 
241                   (result is negated and we discard x^2 because it's constant) */
242                score = 2.f*g*tmp_xy*NORM_SCALING_1 - g*g*tmp_yy;
243
244                if (score>nbest[Lupdate-1]->score)
245                {
246                   int k;
247                   int id = Lupdate-1;
248                   struct NBest *tmp_best;
249
250                   /* Save some pointers that would be deleted and use them for the current entry*/
251                   tmp_best = nbest[Lupdate-1];
252                   while (id > 0 && score > nbest[id-1]->score)
253                      id--;
254                
255                   for (k=Lupdate-1;k>id;k--)
256                      nbest[k] = nbest[k-1];
257
258                   nbest[id] = tmp_best;
259                   nbest[id]->score = score;
260                   nbest[id]->pos = j;
261                   nbest[id]->orig = m;
262                   nbest[id]->sign = sign;
263                   nbest[id]->gain = g;
264                   nbest[id]->xy = tmp_xy;
265                   nbest[id]->yy = tmp_yy;
266                   nbest[id]->yp = tmp_yp;
267                }
268             }
269          }
270
271       }
272       
273       if (!(nbest[0]->score > -1e10f))
274          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
275       /* Only now that we've made the final choice, update ny/iny and others */
276       for (k=0;k<Lupdate;k++)
277       {
278          int n;
279          int is;
280          celt_norm_t s;
281          is = nbest[k]->sign*pulsesAtOnce;
282          s = SHL16(is, yshift);
283          for (n=0;n<N;n++)
284             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
285          ny[k][nbest[k]->pos] += s;
286
287          for (n=0;n<N;n++)
288             iny[k][n] = iy[nbest[k]->orig][n];
289          iny[k][nbest[k]->pos] += is;
290
291          xy[k] = nbest[k]->xy;
292          yy[k] = nbest[k]->yy;
293          yp[k] = nbest[k]->yp;
294       }
295       /* Swap ny/iny with y/iy */
296       for (k=0;k<Lupdate;k++)
297       {
298          celt_norm_t *tmp_ny;
299          int *tmp_iny;
300
301          tmp_ny = ny[k];
302          ny[k] = y[k];
303          y[k] = tmp_ny;
304          tmp_iny = iny[k];
305          iny[k] = iy[k];
306          iy[k] = tmp_iny;
307       }
308       pulsesLeft -= pulsesAtOnce;
309    }
310    
311 #if 0
312    if (0) {
313       float err=0;
314       for (i=0;i<N;i++)
315          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
316       /*if (N<=10)
317         printf ("%f %d %d\n", err, K, N);*/
318    }
319    /* Sanity checks, don't bother */
320    if (0) {
321       for (i=0;i<N;i++)
322          x[i] = p[i]+nbest[0]->gain*y[0][i];
323       float E=1e-15;
324       int ABS = 0;
325       for (i=0;i<N;i++)
326          ABS += abs(iy[0][i]);
327       /*if (K != ABS)
328          printf ("%d %d\n", K, ABS);*/
329       for (i=0;i<N;i++)
330          E += x[i]*x[i];
331       /*printf ("%f\n", E);*/
332       E = 1/sqrt(E);
333       for (i=0;i<N;i++)
334          x[i] *= E;
335    }
336 #endif
337    
338    encode_pulses(iy[0], N, K, enc);
339    
340    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
341       due to the recursive computation used in quantisation.
342       Not quite sure whether we need that or not */
343    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
344 }
345
346 /** Decode pulse vector and combine the result with the pitch vector to produce
347     the final normalised signal in the current band. */
348 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
349 {
350    VARDECL(int *iy);
351    ALLOC(iy, N, int);
352    decode_pulses(iy, N, K, dec);
353    mix_pitch_and_residual(iy, X, N, K, P, alpha);
354 }
355
356
357 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
358
359 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
360 {
361    int i,j;
362    int best=0;
363    float best_score=0;
364    float s = 1;
365    int sign;
366    float E;
367    float pred_gain;
368    int max_pos = N0-N/B;
369    if (max_pos > 32)
370       max_pos = 32;
371
372    for (i=0;i<max_pos*B;i+=B)
373    {
374       int j;
375       float xy=0, yy=0;
376       float score;
377       for (j=0;j<N;j++)
378       {
379          xy += 1.f*x[j]*Y[i+N-j-1];
380          yy += 1.f*Y[i+N-j-1]*Y[i+N-j-1];
381       }
382       score = xy*xy/(.001+yy);
383       if (score > best_score)
384       {
385          best_score = score;
386          best = i;
387          if (xy>0)
388             s = 1;
389          else
390             s = -1;
391       }
392    }
393    if (s<0)
394       sign = 1;
395    else
396       sign = 0;
397    /*printf ("%d %d ", sign, best);*/
398    ec_enc_uint(enc,sign,2);
399    ec_enc_uint(enc,best/B,max_pos);
400    /*printf ("%d %f\n", best, best_score);*/
401    
402    if (K>10)
403       pred_gain = pg[10];
404    else
405       pred_gain = pg[K];
406    E = 1e-10;
407    for (j=0;j<N;j++)
408    {
409       P[j] = s*Y[best+N-j-1];
410       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
411    }
412    E = pred_gain/sqrt(E);
413    for (j=0;j<N;j++)
414       P[j] *= E;
415    if (K>0)
416    {
417       for (j=0;j<N;j++)
418          x[j] -= P[j];
419    } else {
420       for (j=0;j<N;j++)
421          x[j] = P[j];
422    }
423    /*printf ("quant ");*/
424    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
425
426 }
427
428 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
429 {
430    int j;
431    int sign;
432    float s;
433    int best;
434    float E;
435    float pred_gain;
436    int max_pos = N0-N/B;
437    if (max_pos > 32)
438       max_pos = 32;
439    
440    sign = ec_dec_uint(dec, 2);
441    if (sign == 0)
442       s = 1;
443    else
444       s = -1;
445    
446    best = B*ec_dec_uint(dec, max_pos);
447    /*printf ("%d %d ", sign, best);*/
448
449    if (K>10)
450       pred_gain = pg[10];
451    else
452       pred_gain = pg[K];
453    E = 1e-10;
454    for (j=0;j<N;j++)
455    {
456       P[j] = s*Y[best+N-j-1];
457       E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
458    }
459    E = pred_gain/sqrt(E);
460    for (j=0;j<N;j++)
461       P[j] *= E;
462    if (K==0)
463    {
464       for (j=0;j<N;j++)
465          x[j] = P[j];
466    }
467 }
468
469 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
470 {
471    int i, j;
472    float E;
473    
474    E = 1e-10;
475    if (N0 >= Nmax/2)
476    {
477       for (i=0;i<B;i++)
478       {
479          for (j=0;j<N/B;j++)
480          {
481             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
482             E += NORM_SCALING_1*NORM_SCALING_1*P[j*B+i]*P[j*B+i];
483          }
484       }
485    } else {
486       for (j=0;j<N;j++)
487       {
488          P[j] = Y[j];
489          E += NORM_SCALING_1*NORM_SCALING_1*P[j]*P[j];
490       }
491    }
492    E = 1.f/sqrt(E);
493    for (j=0;j<N;j++)
494       P[j] *= E;
495    for (j=0;j<N;j++)
496       x[j] = P[j];
497 }
498