Removed unnecessary header inclusions
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
43    applies the compression in the pitch direction, computes the gain that will
44    give ||p+g*y||=1 and mixes the residual with the pitch. */
45 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
46 {
47    int i;
48    celt_word32_t Ryp, Ryy, Rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 14-EC_ILOG(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    for (i=0;i<N;i++)
64       Rpp = MAC16_16(Rpp,P[i],P[i]);
65
66    Ryp = 0;
67    for (i=0;i<N;i++)
68       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
69
70    /* Remove part of the pitch component to compute the real residual from
71       the encoded (int) one */
72    for (i=0;i<N;i++)
73       y[i] = SUB16(SHL16(iy[i],yshift),
74                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
75
76    /* Recompute after the projection (I think it's right) */
77    Ryp = 0;
78    for (i=0;i<N;i++)
79       Ryp = MAC16_16(Ryp,y[i],P[i]);
80
81    Ryy = 0;
82    for (i=0;i<N;i++)
83       Ryy = MAC16_16(Ryy, y[i],y[i]);
84
85    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
86    g = MULT16_32_Q15(
87             celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy -
88                       MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14)))
89             - ROUND(Ryp,14),
90        celt_rcp(SHR32(Ryy,9)));
91
92    for (i=0;i<N;i++)
93       X[i] = P[i] + ROUND(MULT16_16(y[i], g),11);
94    RESTORE_STACK;
95 }
96
97 /** All the info necessary to keep track of a hypothesis during the search */
98 struct NBest {
99    celt_word32_t score;
100    int sign;
101    int pos;
102    int orig;
103    celt_word32_t xy;
104    celt_word32_t yy;
105    celt_word32_t yp;
106 };
107
108 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
109 {
110    int L = 3;
111    VARDECL(celt_norm_t, _y);
112    VARDECL(celt_norm_t, _ny);
113    VARDECL(int, _iy);
114    VARDECL(int, _iny);
115    VARDECL(celt_norm_t *, y);
116    VARDECL(celt_norm_t *, ny);
117    VARDECL(int *, iy);
118    VARDECL(int *, iny);
119    int i, j, k, m;
120    int pulsesLeft;
121    VARDECL(celt_word32_t, xy);
122    VARDECL(celt_word32_t, yy);
123    VARDECL(celt_word32_t, yp);
124    VARDECL(struct NBest, _nbest);
125    VARDECL(struct NBest *, nbest);
126    celt_word32_t Rpp=0, Rxp=0;
127    int maxL = 1;
128 #ifdef FIXED_POINT
129    int yshift;
130 #endif
131    SAVE_STACK;
132
133 #ifdef FIXED_POINT
134    yshift = 14-EC_ILOG(K);
135 #endif
136
137    ALLOC(_y, L*N, celt_norm_t);
138    ALLOC(_ny, L*N, celt_norm_t);
139    ALLOC(_iy, L*N, int);
140    ALLOC(_iny, L*N, int);
141    ALLOC(y, L, celt_norm_t*);
142    ALLOC(ny, L, celt_norm_t*);
143    ALLOC(iy, L, int*);
144    ALLOC(iny, L, int*);
145    
146    ALLOC(xy, L, celt_word32_t);
147    ALLOC(yy, L, celt_word32_t);
148    ALLOC(yp, L, celt_word32_t);
149    ALLOC(_nbest, L, struct NBest);
150    ALLOC(nbest, L, struct NBest *);
151    
152    for (m=0;m<L;m++)
153       nbest[m] = &_nbest[m];
154    
155    for (m=0;m<L;m++)
156    {
157       ny[m] = &_ny[m*N];
158       iny[m] = &_iny[m*N];
159       y[m] = &_y[m*N];
160       iy[m] = &_iy[m*N];
161    }
162    
163    for (j=0;j<N;j++)
164    {
165       Rpp = MAC16_16(Rpp, P[j],P[j]);
166       Rxp = MAC16_16(Rxp, X[j],P[j]);
167    }
168    Rpp = ROUND(Rpp, NORM_SHIFT);
169    Rxp = ROUND(Rxp, NORM_SHIFT);
170    if (Rpp>NORM_SCALING)
171       celt_fatal("Rpp > 1");
172
173    /* We only need to initialise the zero because the first iteration only uses that */
174    for (i=0;i<N;i++)
175       y[0][i] = 0;
176    for (i=0;i<N;i++)
177       iy[0][i] = 0;
178    xy[0] = yy[0] = yp[0] = 0;
179
180    pulsesLeft = K;
181    while (pulsesLeft > 0)
182    {
183       int pulsesAtOnce=1;
184       int Lupdate = L;
185       int L2 = L;
186       
187       /* Decide on complexity strategy */
188       pulsesAtOnce = pulsesLeft/N;
189       if (pulsesAtOnce<1)
190          pulsesAtOnce = 1;
191       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
192          Lupdate = 1;
193       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
194       L2 = Lupdate;
195       if (L2>maxL)
196       {
197          L2 = maxL;
198          maxL *= N;
199       }
200
201       for (m=0;m<Lupdate;m++)
202          nbest[m]->score = -VERY_LARGE32;
203
204       for (m=0;m<L2;m++)
205       {
206          for (j=0;j<N;j++)
207          {
208             int sign;
209             /*if (x[j]>0) sign=1; else sign=-1;*/
210             for (sign=-1;sign<=1;sign+=2)
211             {
212                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
213                celt_word32_t Rxy, Ryy, Ryp;
214                celt_word16_t spj, aspj; /* Intermediate results */
215                celt_word32_t score;
216                celt_word32_t g;
217                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
218                
219                /* All pulses at one location must have the same sign. */
220                if (iy[m][j]*sign < 0)
221                   continue;
222
223                spj = MULT16_16_Q14(s, P[j]);
224                aspj = MULT16_16_Q15(alpha, spj);
225                /* Updating the sums of the new pulse(s) */
226                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_Q15(alpha,spj),Rxp);
227                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
228                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
229                
230                /* Compute the gain such that ||p + g*y|| = 1 */
231                g = MULT16_32_Q15(
232                         celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy -
233                                   MULT16_16(ROUND(Ryy,14),Rpp))
234                         - ROUND(Ryp,14),
235                    celt_rcp(SHR32(Ryy,12)));
236                /* Knowing that gain, what's the error: (x-g*y)^2 
237                   (result is negated and we discard x^2 because it's constant) */
238                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
239                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g)
240                        - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
241
242                if (score>nbest[Lupdate-1]->score)
243                {
244                   int id = Lupdate-1;
245                   struct NBest *tmp_best;
246
247                   /* Save some pointers that would be deleted and use them for the current entry*/
248                   tmp_best = nbest[Lupdate-1];
249                   while (id > 0 && score > nbest[id-1]->score)
250                      id--;
251                
252                   for (k=Lupdate-1;k>id;k--)
253                      nbest[k] = nbest[k-1];
254
255                   nbest[id] = tmp_best;
256                   nbest[id]->score = score;
257                   nbest[id]->pos = j;
258                   nbest[id]->orig = m;
259                   nbest[id]->sign = sign;
260                   nbest[id]->xy = Rxy;
261                   nbest[id]->yy = Ryy;
262                   nbest[id]->yp = Ryp;
263                }
264             }
265          }
266
267       }
268       
269       if (!(nbest[0]->score > -VERY_LARGE32))
270          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
271       /* Only now that we've made the final choice, update ny/iny and others */
272       for (k=0;k<Lupdate;k++)
273       {
274          int n;
275          int is;
276          celt_norm_t s;
277          is = nbest[k]->sign*pulsesAtOnce;
278          s = SHL16(is, yshift);
279          for (n=0;n<N;n++)
280             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
281          ny[k][nbest[k]->pos] += s;
282
283          for (n=0;n<N;n++)
284             iny[k][n] = iy[nbest[k]->orig][n];
285          iny[k][nbest[k]->pos] += is;
286
287          xy[k] = nbest[k]->xy;
288          yy[k] = nbest[k]->yy;
289          yp[k] = nbest[k]->yp;
290       }
291       /* Swap ny/iny with y/iy */
292       for (k=0;k<Lupdate;k++)
293       {
294          celt_norm_t *tmp_ny;
295          int *tmp_iny;
296
297          tmp_ny = ny[k];
298          ny[k] = y[k];
299          y[k] = tmp_ny;
300          tmp_iny = iny[k];
301          iny[k] = iy[k];
302          iy[k] = tmp_iny;
303       }
304       pulsesLeft -= pulsesAtOnce;
305    }
306    
307 #if 0
308    if (0) {
309       celt_word32_t err=0;
310       for (i=0;i<N;i++)
311          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
312       /*if (N<=10)
313         printf ("%f %d %d\n", err, K, N);*/
314    }
315    /* Sanity checks, don't bother */
316    if (0) {
317       for (i=0;i<N;i++)
318          x[i] = p[i]+nbest[0]->gain*y[0][i];
319       celt_word32_t E=1e-15;
320       int ABS = 0;
321       for (i=0;i<N;i++)
322          ABS += abs(iy[0][i]);
323       /*if (K != ABS)
324          printf ("%d %d\n", K, ABS);*/
325       for (i=0;i<N;i++)
326          E += x[i]*x[i];
327       /*printf ("%f\n", E);*/
328       E = 1/sqrt(E);
329       for (i=0;i<N;i++)
330          x[i] *= E;
331    }
332 #endif
333    
334    encode_pulses(iy[0], N, K, enc);
335    
336    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
337       due to the recursive computation used in quantisation.
338       Not quite sure whether we need that or not */
339    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
340    RESTORE_STACK;
341 }
342
343 /** Decode pulse vector and combine the result with the pitch vector to produce
344     the final normalised signal in the current band. */
345 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
346 {
347    VARDECL(int, iy);
348    SAVE_STACK;
349    ALLOC(iy, N, int);
350    decode_pulses(iy, N, K, dec);
351    mix_pitch_and_residual(iy, X, N, K, P, alpha);
352    RESTORE_STACK;
353 }
354
355 #ifdef FIXED_POINT
356 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
357 #else
358 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
359 #endif
360
361 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
362 {
363    int i,j;
364    int best=0;
365    celt_word32_t best_score=0;
366    celt_word16_t s = 1;
367    int sign;
368    celt_word32_t E;
369    celt_word16_t pred_gain;
370    int max_pos = N0-N/B;
371    if (max_pos > 32)
372       max_pos = 32;
373
374    for (i=0;i<max_pos*B;i+=B)
375    {
376       celt_word32_t xy=0, yy=0;
377       celt_word32_t score;
378       for (j=0;j<N;j++)
379       {
380          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
381          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
382       }
383       score = DIV32(MULT16_16(ROUND(xy,14),ROUND(xy,14)), ROUND(yy,14));
384       if (score > best_score)
385       {
386          best_score = score;
387          best = i;
388          if (xy>0)
389             s = 1;
390          else
391             s = -1;
392       }
393    }
394    if (s<0)
395       sign = 1;
396    else
397       sign = 0;
398    /*printf ("%d %d ", sign, best);*/
399    ec_enc_uint(enc,sign,2);
400    ec_enc_uint(enc,best/B,max_pos);
401    /*printf ("%d %f\n", best, best_score);*/
402    
403    if (K>10)
404       pred_gain = pg[10];
405    else
406       pred_gain = pg[K];
407    E = EPSILON;
408    for (j=0;j<N;j++)
409    {
410       P[j] = s*Y[best+N-j-1];
411       E = MAC16_16(E, P[j],P[j]);
412    }
413    /*pred_gain = pred_gain/sqrt(E);*/
414    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
415    for (j=0;j<N;j++)
416       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
417    if (K>0)
418    {
419       for (j=0;j<N;j++)
420          x[j] -= P[j];
421    } else {
422       for (j=0;j<N;j++)
423          x[j] = P[j];
424    }
425    /*printf ("quant ");*/
426    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
427
428 }
429
430 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
431 {
432    int j;
433    int sign;
434    celt_word16_t s;
435    int best;
436    celt_word32_t E;
437    celt_word16_t pred_gain;
438    int max_pos = N0-N/B;
439    if (max_pos > 32)
440       max_pos = 32;
441    
442    sign = ec_dec_uint(dec, 2);
443    if (sign == 0)
444       s = 1;
445    else
446       s = -1;
447    
448    best = B*ec_dec_uint(dec, max_pos);
449    /*printf ("%d %d ", sign, best);*/
450
451    if (K>10)
452       pred_gain = pg[10];
453    else
454       pred_gain = pg[K];
455    E = EPSILON;
456    for (j=0;j<N;j++)
457    {
458       P[j] = s*Y[best+N-j-1];
459       E = MAC16_16(E, P[j],P[j]);
460    }
461    /*pred_gain = pred_gain/sqrt(E);*/
462    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
463    for (j=0;j<N;j++)
464       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
465    if (K==0)
466    {
467       for (j=0;j<N;j++)
468          x[j] = P[j];
469    }
470 }
471
472 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
473 {
474    int i, j;
475    celt_word32_t E;
476    celt_word16_t g;
477    
478    E = EPSILON;
479    if (N0 >= Nmax/2)
480    {
481       for (i=0;i<B;i++)
482       {
483          for (j=0;j<N/B;j++)
484          {
485             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
486             E += P[j*B+i]*P[j*B+i];
487          }
488       }
489    } else {
490       for (j=0;j<N;j++)
491       {
492          P[j] = Y[j];
493          E = MAC16_16(E, P[j],P[j]);
494       }
495    }
496    g = DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E));
497    for (j=0;j<N;j++)
498       P[j] = PSHR32(MULT16_16(g, P[j]),8);
499    for (j=0;j<N;j++)
500       x[j] = P[j];
501 }
502