fixed-point: simplification of the gain in mix_pitch_and_residual()
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include <math.h>
37 #include <stdlib.h>
38 #include "mathops.h"
39 #include "cwrs.h"
40 #include "vq.h"
41 #include "arch.h"
42 #include "os_support.h"
43
44 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
45    applies the compression in the pitch direction, computes the gain that will
46    give ||p+g*y||=1 and mixes the residual with the pitch. */
47 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
48 {
49    int i;
50    celt_word32_t Ryp, Ryy, Rpp;
51    celt_word32_t g;
52    VARDECL(celt_norm_t, y);
53 #ifdef FIXED_POINT
54    int yshift;
55 #endif
56    SAVE_STACK;
57 #ifdef FIXED_POINT
58    yshift = 14-EC_ILOG(K);
59 #endif
60    ALLOC(y, N, celt_norm_t);
61
62    /*for (i=0;i<N;i++)
63    printf ("%d ", iy[i]);*/
64    Rpp = 0;
65    for (i=0;i<N;i++)
66       Rpp = MAC16_16(Rpp,P[i],P[i]);
67
68    Ryp = 0;
69    for (i=0;i<N;i++)
70       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
71
72    /* Remove part of the pitch component to compute the real residual from
73       the encoded (int) one */
74    for (i=0;i<N;i++)
75       y[i] = SUB16(SHL16(iy[i],yshift),
76                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND(Ryp,14),P[i])));
77
78    /* Recompute after the projection (I think it's right) */
79    Ryp = 0;
80    for (i=0;i<N;i++)
81       Ryp = MAC16_16(Ryp,y[i],P[i]);
82
83    Ryy = 0;
84    for (i=0;i<N;i++)
85       Ryy = MAC16_16(Ryy, y[i],y[i]);
86
87    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
88    g = MULT16_32_Q15(
89             celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy -
90                       MULT16_16(ROUND(Ryy,14),ROUND(Rpp,14)))
91             - ROUND(Ryp,14),
92        celt_rcp(SHR32(Ryy,9)));
93
94    for (i=0;i<N;i++)
95       X[i] = P[i] + ROUND(MULT16_16(y[i], g),11);
96    RESTORE_STACK;
97 }
98
99 /** All the info necessary to keep track of a hypothesis during the search */
100 struct NBest {
101    celt_word32_t score;
102    int sign;
103    int pos;
104    int orig;
105    celt_word32_t xy;
106    celt_word32_t yy;
107    celt_word32_t yp;
108 };
109
110 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
111 {
112    int L = 3;
113    VARDECL(celt_norm_t, _y);
114    VARDECL(celt_norm_t, _ny);
115    VARDECL(int, _iy);
116    VARDECL(int, _iny);
117    VARDECL(celt_norm_t *, y);
118    VARDECL(celt_norm_t *, ny);
119    VARDECL(int *, iy);
120    VARDECL(int *, iny);
121    int i, j, k, m;
122    int pulsesLeft;
123    VARDECL(celt_word32_t, xy);
124    VARDECL(celt_word32_t, yy);
125    VARDECL(celt_word32_t, yp);
126    VARDECL(struct NBest, _nbest);
127    VARDECL(struct NBest *, nbest);
128    celt_word32_t Rpp=0, Rxp=0;
129    int maxL = 1;
130 #ifdef FIXED_POINT
131    int yshift;
132 #endif
133    SAVE_STACK;
134
135 #ifdef FIXED_POINT
136    yshift = 14-EC_ILOG(K);
137 #endif
138
139    ALLOC(_y, L*N, celt_norm_t);
140    ALLOC(_ny, L*N, celt_norm_t);
141    ALLOC(_iy, L*N, int);
142    ALLOC(_iny, L*N, int);
143    ALLOC(y, L, celt_norm_t*);
144    ALLOC(ny, L, celt_norm_t*);
145    ALLOC(iy, L, int*);
146    ALLOC(iny, L, int*);
147    
148    ALLOC(xy, L, celt_word32_t);
149    ALLOC(yy, L, celt_word32_t);
150    ALLOC(yp, L, celt_word32_t);
151    ALLOC(_nbest, L, struct NBest);
152    ALLOC(nbest, L, struct NBest *);
153    
154    for (m=0;m<L;m++)
155       nbest[m] = &_nbest[m];
156    
157    for (m=0;m<L;m++)
158    {
159       ny[m] = &_ny[m*N];
160       iny[m] = &_iny[m*N];
161       y[m] = &_y[m*N];
162       iy[m] = &_iy[m*N];
163    }
164    
165    for (j=0;j<N;j++)
166    {
167       Rpp = MAC16_16(Rpp, P[j],P[j]);
168       Rxp = MAC16_16(Rxp, X[j],P[j]);
169    }
170    Rpp = ROUND(Rpp, NORM_SHIFT);
171    Rxp = ROUND(Rxp, NORM_SHIFT);
172    if (Rpp>NORM_SCALING)
173       celt_fatal("Rpp > 1");
174
175    /* We only need to initialise the zero because the first iteration only uses that */
176    for (i=0;i<N;i++)
177       y[0][i] = 0;
178    for (i=0;i<N;i++)
179       iy[0][i] = 0;
180    xy[0] = yy[0] = yp[0] = 0;
181
182    pulsesLeft = K;
183    while (pulsesLeft > 0)
184    {
185       int pulsesAtOnce=1;
186       int Lupdate = L;
187       int L2 = L;
188       
189       /* Decide on complexity strategy */
190       pulsesAtOnce = pulsesLeft/N;
191       if (pulsesAtOnce<1)
192          pulsesAtOnce = 1;
193       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
194          Lupdate = 1;
195       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
196       L2 = Lupdate;
197       if (L2>maxL)
198       {
199          L2 = maxL;
200          maxL *= N;
201       }
202
203       for (m=0;m<Lupdate;m++)
204          nbest[m]->score = -VERY_LARGE32;
205
206       for (m=0;m<L2;m++)
207       {
208          for (j=0;j<N;j++)
209          {
210             int sign;
211             /*if (x[j]>0) sign=1; else sign=-1;*/
212             for (sign=-1;sign<=1;sign+=2)
213             {
214                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
215                celt_word32_t Rxy, Ryy, Ryp;
216                celt_word16_t spj, aspj; /* Intermediate results */
217                celt_word32_t score;
218                celt_word32_t g;
219                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
220                
221                /* All pulses at one location must have the same sign. */
222                if (iy[m][j]*sign < 0)
223                   continue;
224
225                spj = MULT16_16_Q14(s, P[j]);
226                aspj = MULT16_16_Q15(alpha, spj);
227                /* Updating the sums of the new pulse(s) */
228                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_Q15(alpha,spj),Rxp);
229                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
230                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
231                
232                /* Compute the gain such that ||p + g*y|| = 1 */
233                g = MULT16_32_Q15(
234                         celt_sqrt(MULT16_16(ROUND(Ryp,14),ROUND(Ryp,14)) + Ryy -
235                                   MULT16_16(ROUND(Ryy,14),Rpp))
236                         - ROUND(Ryp,14),
237                    celt_rcp(SHR32(Ryy,12)));
238                /* Knowing that gain, what's the error: (x-g*y)^2 
239                   (result is negated and we discard x^2 because it's constant) */
240                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
241                score = 2*MULT16_32_Q14(ROUND(Rxy,14),g)
242                        - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND(Ryy,14),g)),g);
243
244                if (score>nbest[Lupdate-1]->score)
245                {
246                   int k;
247                   int id = Lupdate-1;
248                   struct NBest *tmp_best;
249
250                   /* Save some pointers that would be deleted and use them for the current entry*/
251                   tmp_best = nbest[Lupdate-1];
252                   while (id > 0 && score > nbest[id-1]->score)
253                      id--;
254                
255                   for (k=Lupdate-1;k>id;k--)
256                      nbest[k] = nbest[k-1];
257
258                   nbest[id] = tmp_best;
259                   nbest[id]->score = score;
260                   nbest[id]->pos = j;
261                   nbest[id]->orig = m;
262                   nbest[id]->sign = sign;
263                   nbest[id]->xy = Rxy;
264                   nbest[id]->yy = Ryy;
265                   nbest[id]->yp = Ryp;
266                }
267             }
268          }
269
270       }
271       
272       if (!(nbest[0]->score > -VERY_LARGE32))
273          celt_fatal("Could not find any match in VQ codebook. Something got corrupted somewhere.");
274       /* Only now that we've made the final choice, update ny/iny and others */
275       for (k=0;k<Lupdate;k++)
276       {
277          int n;
278          int is;
279          celt_norm_t s;
280          is = nbest[k]->sign*pulsesAtOnce;
281          s = SHL16(is, yshift);
282          for (n=0;n<N;n++)
283             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
284          ny[k][nbest[k]->pos] += s;
285
286          for (n=0;n<N;n++)
287             iny[k][n] = iy[nbest[k]->orig][n];
288          iny[k][nbest[k]->pos] += is;
289
290          xy[k] = nbest[k]->xy;
291          yy[k] = nbest[k]->yy;
292          yp[k] = nbest[k]->yp;
293       }
294       /* Swap ny/iny with y/iy */
295       for (k=0;k<Lupdate;k++)
296       {
297          celt_norm_t *tmp_ny;
298          int *tmp_iny;
299
300          tmp_ny = ny[k];
301          ny[k] = y[k];
302          y[k] = tmp_ny;
303          tmp_iny = iny[k];
304          iny[k] = iy[k];
305          iy[k] = tmp_iny;
306       }
307       pulsesLeft -= pulsesAtOnce;
308    }
309    
310 #if 0
311    if (0) {
312       celt_word32_t err=0;
313       for (i=0;i<N;i++)
314          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
315       /*if (N<=10)
316         printf ("%f %d %d\n", err, K, N);*/
317    }
318    /* Sanity checks, don't bother */
319    if (0) {
320       for (i=0;i<N;i++)
321          x[i] = p[i]+nbest[0]->gain*y[0][i];
322       celt_word32_t E=1e-15;
323       int ABS = 0;
324       for (i=0;i<N;i++)
325          ABS += abs(iy[0][i]);
326       /*if (K != ABS)
327          printf ("%d %d\n", K, ABS);*/
328       for (i=0;i<N;i++)
329          E += x[i]*x[i];
330       /*printf ("%f\n", E);*/
331       E = 1/sqrt(E);
332       for (i=0;i<N;i++)
333          x[i] *= E;
334    }
335 #endif
336    
337    encode_pulses(iy[0], N, K, enc);
338    
339    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
340       due to the recursive computation used in quantisation.
341       Not quite sure whether we need that or not */
342    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
343    RESTORE_STACK;
344 }
345
346 /** Decode pulse vector and combine the result with the pitch vector to produce
347     the final normalised signal in the current band. */
348 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
349 {
350    VARDECL(int, iy);
351    SAVE_STACK;
352    ALLOC(iy, N, int);
353    decode_pulses(iy, N, K, dec);
354    mix_pitch_and_residual(iy, X, N, K, P, alpha);
355    RESTORE_STACK;
356 }
357
358 #ifdef FIXED_POINT
359 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
360 #else
361 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
362 #endif
363
364 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
365 {
366    int i,j;
367    int best=0;
368    celt_word32_t best_score=0;
369    celt_word16_t s = 1;
370    int sign;
371    celt_word32_t E;
372    celt_word16_t pred_gain;
373    int max_pos = N0-N/B;
374    if (max_pos > 32)
375       max_pos = 32;
376
377    for (i=0;i<max_pos*B;i+=B)
378    {
379       int j;
380       celt_word32_t xy=0, yy=0;
381       celt_word32_t score;
382       for (j=0;j<N;j++)
383       {
384          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
385          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
386       }
387       score = DIV32(MULT16_16(ROUND(xy,14),ROUND(xy,14)), ROUND(yy,14));
388       if (score > best_score)
389       {
390          best_score = score;
391          best = i;
392          if (xy>0)
393             s = 1;
394          else
395             s = -1;
396       }
397    }
398    if (s<0)
399       sign = 1;
400    else
401       sign = 0;
402    /*printf ("%d %d ", sign, best);*/
403    ec_enc_uint(enc,sign,2);
404    ec_enc_uint(enc,best/B,max_pos);
405    /*printf ("%d %f\n", best, best_score);*/
406    
407    if (K>10)
408       pred_gain = pg[10];
409    else
410       pred_gain = pg[K];
411    E = EPSILON;
412    for (j=0;j<N;j++)
413    {
414       P[j] = s*Y[best+N-j-1];
415       E = MAC16_16(E, P[j],P[j]);
416    }
417    /*pred_gain = pred_gain/sqrt(E);*/
418    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
419    for (j=0;j<N;j++)
420       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
421    if (K>0)
422    {
423       for (j=0;j<N;j++)
424          x[j] -= P[j];
425    } else {
426       for (j=0;j<N;j++)
427          x[j] = P[j];
428    }
429    /*printf ("quant ");*/
430    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
431
432 }
433
434 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
435 {
436    int j;
437    int sign;
438    celt_word16_t s;
439    int best;
440    celt_word32_t E;
441    celt_word16_t pred_gain;
442    int max_pos = N0-N/B;
443    if (max_pos > 32)
444       max_pos = 32;
445    
446    sign = ec_dec_uint(dec, 2);
447    if (sign == 0)
448       s = 1;
449    else
450       s = -1;
451    
452    best = B*ec_dec_uint(dec, max_pos);
453    /*printf ("%d %d ", sign, best);*/
454
455    if (K>10)
456       pred_gain = pg[10];
457    else
458       pred_gain = pg[K];
459    E = EPSILON;
460    for (j=0;j<N;j++)
461    {
462       P[j] = s*Y[best+N-j-1];
463       E = MAC16_16(E, P[j],P[j]);
464    }
465    /*pred_gain = pred_gain/sqrt(E);*/
466    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
467    for (j=0;j<N;j++)
468       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
469    if (K==0)
470    {
471       for (j=0;j<N;j++)
472          x[j] = P[j];
473    }
474 }
475
476 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
477 {
478    int i, j;
479    celt_word32_t E;
480    celt_word16_t g;
481    
482    E = EPSILON;
483    if (N0 >= Nmax/2)
484    {
485       for (i=0;i<B;i++)
486       {
487          for (j=0;j<N/B;j++)
488          {
489             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
490             E += P[j*B+i]*P[j*B+i];
491          }
492       }
493    } else {
494       for (j=0;j<N;j++)
495       {
496          P[j] = Y[j];
497          E = MAC16_16(E, P[j],P[j]);
498       }
499    }
500    g = DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E));
501    for (j=0;j<N;j++)
502       P[j] = PSHR32(MULT16_16(g, P[j]),8);
503    for (j=0;j<N;j++)
504       x[j] = P[j];
505 }
506