Bit of cleaning up. No real code change (well, I hope so!).
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
45    applies the compression in the pitch direction, computes the gain that will
46    give ||p+g*y||=1 and mixes the residual with the pitch. */
47 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
48 {
49    int i;
50    celt_word32_t Ryp, Ryy, Rpp;
51    celt_word32_t g;
52    VARDECL(celt_norm_t *y);
53 #ifdef FIXED_POINT
54    int yshift;
55 #endif
56    SAVE_STACK;
57 #ifdef FIXED_POINT
58    yshift = 14-EC_ILOG(K);
59 #endif
60    ALLOC(y, N, celt_norm_t);
61
62    /*for (i=0;i<N;i++)
63    printf ("%d ", iy[i]);*/
64    Rpp = 0;
65    for (i=0;i<N;i++)
66       Rpp = MAC16_16(Rpp,P[i],P[i]);
67
68    Ryp = 0;
69    for (i=0;i<N;i++)
70       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
71
72    /* Remove part of the pitch component to compute the real residual from
73       the encoded (int) one */
74    for (i=0;i<N;i++)
75       y[i] = SUB16(SHL16(iy[i],yshift),
76                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
77
78    /* Recompute after the projection (I think it's right) */
79    Ryp = 0;
80    for (i=0;i<N;i++)
81       Ryp = MAC16_16(Ryp,y[i],P[i]);
82
83    Ryy = 0;
84    for (i=0;i<N;i++)
85       Ryy = MAC16_16(Ryy, y[i],y[i]);
86
87    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
88    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
89
90    for (i=0;i<N;i++)
91       X[i] = P[i] + MULT16_32_Q14(y[i], g);
92    RESTORE_STACK;
93 }
94
95 /** All the info necessary to keep track of a hypothesis during the search */
96 struct NBest {
97    celt_word32_t score;
98    int sign;
99    int pos;
100    int orig;
101    celt_word32_t xy;
102    celt_word32_t yy;
103    celt_word32_t yp;
104 };
105
106 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
107 {
108    int L = 3;
109    VARDECL(celt_norm_t *_y);
110    VARDECL(celt_norm_t *_ny);
111    VARDECL(int *_iy);
112    VARDECL(int *_iny);
113    VARDECL(celt_norm_t **y);
114    VARDECL(celt_norm_t **ny);
115    VARDECL(int **iy);
116    VARDECL(int **iny);
117    int i, j, k, m;
118    int pulsesLeft;
119    VARDECL(celt_word32_t *xy);
120    VARDECL(celt_word32_t *yy);
121    VARDECL(celt_word32_t *yp);
122    VARDECL(struct NBest *_nbest);
123    VARDECL(struct NBest **nbest);
124    celt_word32_t Rpp=0, Rxp=0;
125    int maxL = 1;
126 #ifdef FIXED_POINT
127    int yshift;
128 #endif
129    SAVE_STACK;
130
131 #ifdef FIXED_POINT
132    yshift = 14-EC_ILOG(K);
133 #endif
134
135    ALLOC(_y, L*N, celt_norm_t);
136    ALLOC(_ny, L*N, celt_norm_t);
137    ALLOC(_iy, L*N, int);
138    ALLOC(_iny, L*N, int);
139    ALLOC(y, L, celt_norm_t*);
140    ALLOC(ny, L, celt_norm_t*);
141    ALLOC(iy, L, int*);
142    ALLOC(iny, L, int*);
143    
144    ALLOC(xy, L, celt_word32_t);
145    ALLOC(yy, L, celt_word32_t);
146    ALLOC(yp, L, celt_word32_t);
147    ALLOC(_nbest, L, struct NBest);
148    ALLOC(nbest, L, struct NBest *);
149    
150    for (m=0;m<L;m++)
151       nbest[m] = &_nbest[m];
152    
153    for (m=0;m<L;m++)
154    {
155       ny[m] = &_ny[m*N];
156       iny[m] = &_iny[m*N];
157       y[m] = &_y[m*N];
158       iy[m] = &_iy[m*N];
159    }
160    
161    for (j=0;j<N;j++)
162    {
163       Rpp = MAC16_16(Rpp, P[j],P[j]);
164       Rxp = MAC16_16(Rxp, X[j],P[j]);
165    }
166    Rpp = ROUND(Rpp, NORM_SHIFT);
167    Rxp = ROUND(Rxp, NORM_SHIFT);
168    if (Rpp>NORM_SCALING)
169       celt_fatal("Rpp > 1");
170
171    /* We only need to initialise the zero because the first iteration only uses that */
172    for (i=0;i<N;i++)
173       y[0][i] = 0;
174    for (i=0;i<N;i++)
175       iy[0][i] = 0;
176    xy[0] = yy[0] = yp[0] = 0;
177
178    pulsesLeft = K;
179    while (pulsesLeft > 0)
180    {
181       int pulsesAtOnce=1;
182       int Lupdate = L;
183       int L2 = L;
184       
185       /* Decide on complexity strategy */
186       pulsesAtOnce = pulsesLeft/N;
187       if (pulsesAtOnce<1)
188          pulsesAtOnce = 1;
189       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
190          Lupdate = 1;
191       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
192       L2 = Lupdate;
193       if (L2>maxL)
194       {
195          L2 = maxL;
196          maxL *= N;
197       }
198
199       for (m=0;m<Lupdate;m++)
200          nbest[m]->score = -VERY_LARGE32;
201
202       for (m=0;m<L2;m++)
203       {
204          for (j=0;j<N;j++)
205          {
206             int sign;
207             /*if (x[j]>0) sign=1; else sign=-1;*/
208             for (sign=-1;sign<=1;sign+=2)
209             {
210                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
211                celt_word32_t Rxy, Ryy, Ryp;
212                celt_word16_t spj, aspj; /* Intermediate results */
213                celt_word32_t score;
214                celt_word32_t g;
215                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
216                
217                /* All pulses at one location must have the same sign. */
218                if (iy[m][j]*sign < 0)
219                   continue;
220
221                spj = MULT16_16_P14(s, P[j]);
222                aspj = MULT16_16_P15(alpha, spj);
223                /* Updating the sums of the new pulse(s) */
224                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_P15(alpha,spj),Rxp);
225                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
226                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
227                
228                /* Compute the gain such that ||p + g*y|| = 1 */
229                g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),Rpp)) - ROUND(Ryp,14),14),ROUND(Ryy,14));
230                
231                /* Knowing that gain, what the error: (x-g*y)^2 
232                   (result is negated and we discard x^2 because it's constant) */
233                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
234                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g) -
235                      MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
236
237                if (score>nbest[Lupdate-1]->score)
238                {
239                   int k;
240                   int id = Lupdate-1;
241                   struct NBest *tmp_best;
242
243                   /* Save some pointers that would be deleted and use them for the current entry*/
244                   tmp_best = nbest[Lupdate-1];
245                   while (id > 0 && score > nbest[id-1]->score)
246                      id--;
247                
248                   for (k=Lupdate-1;k>id;k--)
249                      nbest[k] = nbest[k-1];
250
251                   nbest[id] = tmp_best;
252                   nbest[id]->score = score;
253                   nbest[id]->pos = j;
254                   nbest[id]->orig = m;
255                   nbest[id]->sign = sign;
256                   nbest[id]->xy = Rxy;
257                   nbest[id]->yy = Ryy;
258                   nbest[id]->yp = Ryp;
259                }
260             }
261          }
262
263       }
264       
265       if (!(nbest[0]->score > -VERY_LARGE32))
266          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
267       /* Only now that we've made the final choice, update ny/iny and others */
268       for (k=0;k<Lupdate;k++)
269       {
270          int n;
271          int is;
272          celt_norm_t s;
273          is = nbest[k]->sign*pulsesAtOnce;
274          s = SHL16(is, yshift);
275          for (n=0;n<N;n++)
276             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
277          ny[k][nbest[k]->pos] += s;
278
279          for (n=0;n<N;n++)
280             iny[k][n] = iy[nbest[k]->orig][n];
281          iny[k][nbest[k]->pos] += is;
282
283          xy[k] = nbest[k]->xy;
284          yy[k] = nbest[k]->yy;
285          yp[k] = nbest[k]->yp;
286       }
287       /* Swap ny/iny with y/iy */
288       for (k=0;k<Lupdate;k++)
289       {
290          celt_norm_t *tmp_ny;
291          int *tmp_iny;
292
293          tmp_ny = ny[k];
294          ny[k] = y[k];
295          y[k] = tmp_ny;
296          tmp_iny = iny[k];
297          iny[k] = iy[k];
298          iy[k] = tmp_iny;
299       }
300       pulsesLeft -= pulsesAtOnce;
301    }
302    
303 #if 0
304    if (0) {
305       celt_word32_t err=0;
306       for (i=0;i<N;i++)
307          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
308       /*if (N<=10)
309         printf ("%f %d %d\n", err, K, N);*/
310    }
311    /* Sanity checks, don't bother */
312    if (0) {
313       for (i=0;i<N;i++)
314          x[i] = p[i]+nbest[0]->gain*y[0][i];
315       celt_word32_t E=1e-15;
316       int ABS = 0;
317       for (i=0;i<N;i++)
318          ABS += abs(iy[0][i]);
319       /*if (K != ABS)
320          printf ("%d %d\n", K, ABS);*/
321       for (i=0;i<N;i++)
322          E += x[i]*x[i];
323       /*printf ("%f\n", E);*/
324       E = 1/sqrt(E);
325       for (i=0;i<N;i++)
326          x[i] *= E;
327    }
328 #endif
329    
330    encode_pulses(iy[0], N, K, enc);
331    
332    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
333       due to the recursive computation used in quantisation.
334       Not quite sure whether we need that or not */
335    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
336    RESTORE_STACK;
337 }
338
339 /** Decode pulse vector and combine the result with the pitch vector to produce
340     the final normalised signal in the current band. */
341 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
342 {
343    VARDECL(int *iy);
344    SAVE_STACK;
345    ALLOC(iy, N, int);
346    decode_pulses(iy, N, K, dec);
347    mix_pitch_and_residual(iy, X, N, K, P, alpha);
348    RESTORE_STACK;
349 }
350
351 #ifdef FIXED_POINT
352 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
353 #else
354 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
355 #endif
356
357 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
358 {
359    int i,j;
360    int best=0;
361    celt_word32_t best_score=0;
362    celt_word16_t s = 1;
363    int sign;
364    celt_word32_t E;
365    celt_word16_t pred_gain;
366    int max_pos = N0-N/B;
367    if (max_pos > 32)
368       max_pos = 32;
369
370    for (i=0;i<max_pos*B;i+=B)
371    {
372       int j;
373       celt_word32_t xy=0, yy=0;
374       celt_word32_t score;
375       for (j=0;j<N;j++)
376       {
377          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
378          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
379       }
380       score = DIV32(MULT16_16(ROUND(xy,14),ROUND(xy,14)), ROUND(yy,14));
381       if (score > best_score)
382       {
383          best_score = score;
384          best = i;
385          if (xy>0)
386             s = 1;
387          else
388             s = -1;
389       }
390    }
391    if (s<0)
392       sign = 1;
393    else
394       sign = 0;
395    /*printf ("%d %d ", sign, best);*/
396    ec_enc_uint(enc,sign,2);
397    ec_enc_uint(enc,best/B,max_pos);
398    /*printf ("%d %f\n", best, best_score);*/
399    
400    if (K>10)
401       pred_gain = pg[10];
402    else
403       pred_gain = pg[K];
404    E = EPSILON;
405    for (j=0;j<N;j++)
406    {
407       P[j] = s*Y[best+N-j-1];
408       E = MAC16_16(E, P[j],P[j]);
409    }
410    /*pred_gain = pred_gain/sqrt(E);*/
411    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
412    for (j=0;j<N;j++)
413       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
414    if (K>0)
415    {
416       for (j=0;j<N;j++)
417          x[j] -= P[j];
418    } else {
419       for (j=0;j<N;j++)
420          x[j] = P[j];
421    }
422    /*printf ("quant ");*/
423    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
424
425 }
426
427 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
428 {
429    int j;
430    int sign;
431    celt_word16_t s;
432    int best;
433    celt_word32_t E;
434    celt_word16_t pred_gain;
435    int max_pos = N0-N/B;
436    if (max_pos > 32)
437       max_pos = 32;
438    
439    sign = ec_dec_uint(dec, 2);
440    if (sign == 0)
441       s = 1;
442    else
443       s = -1;
444    
445    best = B*ec_dec_uint(dec, max_pos);
446    /*printf ("%d %d ", sign, best);*/
447
448    if (K>10)
449       pred_gain = pg[10];
450    else
451       pred_gain = pg[K];
452    E = EPSILON;
453    for (j=0;j<N;j++)
454    {
455       P[j] = s*Y[best+N-j-1];
456       E = MAC16_16(E, P[j],P[j]);
457    }
458    /*pred_gain = pred_gain/sqrt(E);*/
459    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
460    for (j=0;j<N;j++)
461       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
462    if (K==0)
463    {
464       for (j=0;j<N;j++)
465          x[j] = P[j];
466    }
467 }
468
469 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
470 {
471    int i, j;
472    celt_word32_t E;
473    celt_word16_t g;
474    
475    E = EPSILON;
476    if (N0 >= Nmax/2)
477    {
478       for (i=0;i<B;i++)
479       {
480          for (j=0;j<N/B;j++)
481          {
482             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
483             E += P[j*B+i]*P[j*B+i];
484          }
485       }
486    } else {
487       for (j=0;j<N;j++)
488       {
489          P[j] = Y[j];
490          E = MAC16_16(E, P[j],P[j]);
491       }
492    }
493    g = DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E));
494    for (j=0;j<N;j++)
495       P[j] = PSHR32(MULT16_16(g, P[j]),8);
496    for (j=0;j<N;j++)
497       x[j] = P[j];
498 }
499