7c89d02df519bfebb9d997745c80a2a0af79b5d8
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
45    applies the compression in the pitch direction, computes the gain that will
46    give ||p+g*y||=1 and mixes the residual with the pitch. */
47 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
48 {
49    int i;
50    celt_word32_t Ryp, Ryy, Rpp;
51    celt_word32_t g;
52    VARDECL(celt_norm_t, y);
53 #ifdef FIXED_POINT
54    int yshift;
55 #endif
56    SAVE_STACK;
57 #ifdef FIXED_POINT
58    yshift = 14-EC_ILOG(K);
59 #endif
60    ALLOC(y, N, celt_norm_t);
61
62    /*for (i=0;i<N;i++)
63    printf ("%d ", iy[i]);*/
64    Rpp = 0;
65    for (i=0;i<N;i++)
66       Rpp = MAC16_16(Rpp,P[i],P[i]);
67
68    Ryp = 0;
69    for (i=0;i<N;i++)
70       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
71
72    /* Remove part of the pitch component to compute the real residual from
73       the encoded (int) one */
74    for (i=0;i<N;i++)
75       y[i] = SUB16(SHL16(iy[i],yshift),
76                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
77
78    /* Recompute after the projection (I think it's right) */
79    Ryp = 0;
80    for (i=0;i<N;i++)
81       Ryp = MAC16_16(Ryp,y[i],P[i]);
82
83    Ryy = 0;
84    for (i=0;i<N;i++)
85       Ryy = MAC16_16(Ryy, y[i],y[i]);
86
87    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
88    g = DIV32(SHL32(celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy - MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14))) - ROUND(Ryp,14),14),ROUND(Ryy,14));
89
90    for (i=0;i<N;i++)
91       X[i] = P[i] + MULT16_32_Q14(y[i], g);
92    RESTORE_STACK;
93 }
94
95 /** All the info necessary to keep track of a hypothesis during the search */
96 struct NBest {
97    celt_word32_t score;
98    int sign;
99    int pos;
100    int orig;
101    celt_word32_t xy;
102    celt_word32_t yy;
103    celt_word32_t yp;
104 };
105
106 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
107 {
108    int L = 3;
109    VARDECL(celt_norm_t, _y);
110    VARDECL(celt_norm_t, _ny);
111    VARDECL(int, _iy);
112    VARDECL(int, _iny);
113    VARDECL(celt_norm_t *, y);
114    VARDECL(celt_norm_t *, ny);
115    VARDECL(int *, iy);
116    VARDECL(int *, iny);
117    int i, j, k, m;
118    int pulsesLeft;
119    VARDECL(celt_word32_t, xy);
120    VARDECL(celt_word32_t, yy);
121    VARDECL(celt_word32_t, yp);
122    VARDECL(struct NBest, _nbest);
123    VARDECL(struct NBest *, nbest);
124    celt_word32_t Rpp=0, Rxp=0;
125    int maxL = 1;
126 #ifdef FIXED_POINT
127    int yshift;
128 #endif
129    SAVE_STACK;
130
131 #ifdef FIXED_POINT
132    yshift = 14-EC_ILOG(K);
133 #endif
134
135    ALLOC(_y, L*N, celt_norm_t);
136    ALLOC(_ny, L*N, celt_norm_t);
137    ALLOC(_iy, L*N, int);
138    ALLOC(_iny, L*N, int);
139    ALLOC(y, L, celt_norm_t*);
140    ALLOC(ny, L, celt_norm_t*);
141    ALLOC(iy, L, int*);
142    ALLOC(iny, L, int*);
143    
144    ALLOC(xy, L, celt_word32_t);
145    ALLOC(yy, L, celt_word32_t);
146    ALLOC(yp, L, celt_word32_t);
147    ALLOC(_nbest, L, struct NBest);
148    ALLOC(nbest, L, struct NBest *);
149    
150    for (m=0;m<L;m++)
151       nbest[m] = &_nbest[m];
152    
153    for (m=0;m<L;m++)
154    {
155       ny[m] = &_ny[m*N];
156       iny[m] = &_iny[m*N];
157       y[m] = &_y[m*N];
158       iy[m] = &_iy[m*N];
159    }
160    
161    for (j=0;j<N;j++)
162    {
163       Rpp = MAC16_16(Rpp, P[j],P[j]);
164       Rxp = MAC16_16(Rxp, X[j],P[j]);
165    }
166    Rpp = ROUND(Rpp, NORM_SHIFT);
167    Rxp = ROUND(Rxp, NORM_SHIFT);
168    if (Rpp>NORM_SCALING)
169       celt_fatal("Rpp > 1");
170
171    /* We only need to initialise the zero because the first iteration only uses that */
172    for (i=0;i<N;i++)
173       y[0][i] = 0;
174    for (i=0;i<N;i++)
175       iy[0][i] = 0;
176    xy[0] = yy[0] = yp[0] = 0;
177
178    pulsesLeft = K;
179    while (pulsesLeft > 0)
180    {
181       int pulsesAtOnce=1;
182       int Lupdate = L;
183       int L2 = L;
184       
185       /* Decide on complexity strategy */
186       pulsesAtOnce = pulsesLeft/N;
187       if (pulsesAtOnce<1)
188          pulsesAtOnce = 1;
189       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
190          Lupdate = 1;
191       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
192       L2 = Lupdate;
193       if (L2>maxL)
194       {
195          L2 = maxL;
196          maxL *= N;
197       }
198
199       for (m=0;m<Lupdate;m++)
200          nbest[m]->score = -VERY_LARGE32;
201
202       for (m=0;m<L2;m++)
203       {
204          for (j=0;j<N;j++)
205          {
206             int sign;
207             /*if (x[j]>0) sign=1; else sign=-1;*/
208             for (sign=-1;sign<=1;sign+=2)
209             {
210                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
211                celt_word32_t Rxy, Ryy, Ryp;
212                celt_word16_t spj, aspj; /* Intermediate results */
213                celt_word32_t score;
214                celt_word32_t g;
215                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
216                
217                /* All pulses at one location must have the same sign. */
218                if (iy[m][j]*sign < 0)
219                   continue;
220
221                spj = MULT16_16_Q14(s, P[j]);
222                aspj = MULT16_16_Q15(alpha, spj);
223                /* Updating the sums of the new pulse(s) */
224                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_Q15(alpha,spj),Rxp);
225                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
226                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
227                
228                /* Compute the gain such that ||p + g*y|| = 1 */
229                g = MULT16_32_Q15(
230                         celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy -
231                                   MULT16_16(ROUND(Ryy,14),Rpp))
232                         - ROUND(Ryp,14),
233                    celt_rcp(SHR32(Ryy,12)));
234                /* Knowing that gain, what's the error: (x-g*y)^2 
235                   (result is negated and we discard x^2 because it's constant) */
236                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
237                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g)
238                        - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
239
240                if (score>nbest[Lupdate-1]->score)
241                {
242                   int k;
243                   int id = Lupdate-1;
244                   struct NBest *tmp_best;
245
246                   /* Save some pointers that would be deleted and use them for the current entry*/
247                   tmp_best = nbest[Lupdate-1];
248                   while (id > 0 && score > nbest[id-1]->score)
249                      id--;
250                
251                   for (k=Lupdate-1;k>id;k--)
252                      nbest[k] = nbest[k-1];
253
254                   nbest[id] = tmp_best;
255                   nbest[id]->score = score;
256                   nbest[id]->pos = j;
257                   nbest[id]->orig = m;
258                   nbest[id]->sign = sign;
259                   nbest[id]->xy = Rxy;
260                   nbest[id]->yy = Ryy;
261                   nbest[id]->yp = Ryp;
262                }
263             }
264          }
265
266       }
267       
268       if (!(nbest[0]->score > -VERY_LARGE32))
269          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
270       /* Only now that we've made the final choice, update ny/iny and others */
271       for (k=0;k<Lupdate;k++)
272       {
273          int n;
274          int is;
275          celt_norm_t s;
276          is = nbest[k]->sign*pulsesAtOnce;
277          s = SHL16(is, yshift);
278          for (n=0;n<N;n++)
279             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
280          ny[k][nbest[k]->pos] += s;
281
282          for (n=0;n<N;n++)
283             iny[k][n] = iy[nbest[k]->orig][n];
284          iny[k][nbest[k]->pos] += is;
285
286          xy[k] = nbest[k]->xy;
287          yy[k] = nbest[k]->yy;
288          yp[k] = nbest[k]->yp;
289       }
290       /* Swap ny/iny with y/iy */
291       for (k=0;k<Lupdate;k++)
292       {
293          celt_norm_t *tmp_ny;
294          int *tmp_iny;
295
296          tmp_ny = ny[k];
297          ny[k] = y[k];
298          y[k] = tmp_ny;
299          tmp_iny = iny[k];
300          iny[k] = iy[k];
301          iy[k] = tmp_iny;
302       }
303       pulsesLeft -= pulsesAtOnce;
304    }
305    
306 #if 0
307    if (0) {
308       celt_word32_t err=0;
309       for (i=0;i<N;i++)
310          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
311       /*if (N<=10)
312         printf ("%f %d %d\n", err, K, N);*/
313    }
314    /* Sanity checks, don't bother */
315    if (0) {
316       for (i=0;i<N;i++)
317          x[i] = p[i]+nbest[0]->gain*y[0][i];
318       celt_word32_t E=1e-15;
319       int ABS = 0;
320       for (i=0;i<N;i++)
321          ABS += abs(iy[0][i]);
322       /*if (K != ABS)
323          printf ("%d %d\n", K, ABS);*/
324       for (i=0;i<N;i++)
325          E += x[i]*x[i];
326       /*printf ("%f\n", E);*/
327       E = 1/sqrt(E);
328       for (i=0;i<N;i++)
329          x[i] *= E;
330    }
331 #endif
332    
333    encode_pulses(iy[0], N, K, enc);
334    
335    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
336       due to the recursive computation used in quantisation.
337       Not quite sure whether we need that or not */
338    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
339    RESTORE_STACK;
340 }
341
342 /** Decode pulse vector and combine the result with the pitch vector to produce
343     the final normalised signal in the current band. */
344 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
345 {
346    VARDECL(int, iy);
347    SAVE_STACK;
348    ALLOC(iy, N, int);
349    decode_pulses(iy, N, K, dec);
350    mix_pitch_and_residual(iy, X, N, K, P, alpha);
351    RESTORE_STACK;
352 }
353
354 #ifdef FIXED_POINT
355 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
356 #else
357 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
358 #endif
359
360 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
361 {
362    int i,j;
363    int best=0;
364    celt_word32_t best_score=0;
365    celt_word16_t s = 1;
366    int sign;
367    celt_word32_t E;
368    celt_word16_t pred_gain;
369    int max_pos = N0-N/B;
370    if (max_pos > 32)
371       max_pos = 32;
372
373    for (i=0;i<max_pos*B;i+=B)
374    {
375       int j;
376       celt_word32_t xy=0, yy=0;
377       celt_word32_t score;
378       for (j=0;j<N;j++)
379       {
380          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
381          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
382       }
383       score = DIV32(MULT16_16(ROUND(xy,14),ROUND(xy,14)), ROUND(yy,14));
384       if (score > best_score)
385       {
386          best_score = score;
387          best = i;
388          if (xy>0)
389             s = 1;
390          else
391             s = -1;
392       }
393    }
394    if (s<0)
395       sign = 1;
396    else
397       sign = 0;
398    /*printf ("%d %d ", sign, best);*/
399    ec_enc_uint(enc,sign,2);
400    ec_enc_uint(enc,best/B,max_pos);
401    /*printf ("%d %f\n", best, best_score);*/
402    
403    if (K>10)
404       pred_gain = pg[10];
405    else
406       pred_gain = pg[K];
407    E = EPSILON;
408    for (j=0;j<N;j++)
409    {
410       P[j] = s*Y[best+N-j-1];
411       E = MAC16_16(E, P[j],P[j]);
412    }
413    /*pred_gain = pred_gain/sqrt(E);*/
414    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
415    for (j=0;j<N;j++)
416       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
417    if (K>0)
418    {
419       for (j=0;j<N;j++)
420          x[j] -= P[j];
421    } else {
422       for (j=0;j<N;j++)
423          x[j] = P[j];
424    }
425    /*printf ("quant ");*/
426    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
427
428 }
429
430 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
431 {
432    int j;
433    int sign;
434    celt_word16_t s;
435    int best;
436    celt_word32_t E;
437    celt_word16_t pred_gain;
438    int max_pos = N0-N/B;
439    if (max_pos > 32)
440       max_pos = 32;
441    
442    sign = ec_dec_uint(dec, 2);
443    if (sign == 0)
444       s = 1;
445    else
446       s = -1;
447    
448    best = B*ec_dec_uint(dec, max_pos);
449    /*printf ("%d %d ", sign, best);*/
450
451    if (K>10)
452       pred_gain = pg[10];
453    else
454       pred_gain = pg[K];
455    E = EPSILON;
456    for (j=0;j<N;j++)
457    {
458       P[j] = s*Y[best+N-j-1];
459       E = MAC16_16(E, P[j],P[j]);
460    }
461    /*pred_gain = pred_gain/sqrt(E);*/
462    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
463    for (j=0;j<N;j++)
464       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
465    if (K==0)
466    {
467       for (j=0;j<N;j++)
468          x[j] = P[j];
469    }
470 }
471
472 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
473 {
474    int i, j;
475    celt_word32_t E;
476    celt_word16_t g;
477    
478    E = EPSILON;
479    if (N0 >= Nmax/2)
480    {
481       for (i=0;i<B;i++)
482       {
483          for (j=0;j<N/B;j++)
484          {
485             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
486             E += P[j*B+i]*P[j*B+i];
487          }
488       }
489    } else {
490       for (j=0;j<N;j++)
491       {
492          P[j] = Y[j];
493          E = MAC16_16(E, P[j],P[j]);
494       }
495    }
496    g = DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E));
497    for (j=0;j<N;j++)
498       P[j] = PSHR32(MULT16_16(g, P[j]),8);
499    for (j=0;j<N;j++)
500       x[j] = P[j];
501 }
502