s/ROUND/ROUND16/
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
43    applies the compression in the pitch direction, computes the gain that will
44    give ||p+g*y||=1 and mixes the residual with the pitch. */
45 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P, celt_word16_t alpha)
46 {
47    int i;
48    celt_word32_t Ryp, Ryy, Rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 14-EC_ILOG(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    for (i=0;i<N;i++)
64       Rpp = MAC16_16(Rpp,P[i],P[i]);
65
66    Ryp = 0;
67    for (i=0;i<N;i++)
68       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
69
70    /* Remove part of the pitch component to compute the real residual from
71       the encoded (int) one */
72    for (i=0;i<N;i++)
73       y[i] = SUB16(SHL16(iy[i],yshift),
74                    MULT16_16_Q15(alpha,MULT16_16_Q14(ROUND16(Ryp,14),P[i])));
75
76    /* Recompute after the projection (I think it's right) */
77    Ryp = 0;
78    for (i=0;i<N;i++)
79       Ryp = MAC16_16(Ryp,y[i],P[i]);
80
81    Ryy = 0;
82    for (i=0;i<N;i++)
83       Ryy = MAC16_16(Ryy, y[i],y[i]);
84
85    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
86    g = MULT16_32_Q15(
87             celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
88                       MULT16_16(ROUND16(Ryy,14),ROUND16(Rpp,14)))
89             - ROUND16(Ryp,14),
90        celt_rcp(SHR32(Ryy,9)));
91
92    for (i=0;i<N;i++)
93       X[i] = P[i] + ROUND16(MULT16_16(y[i], g),11);
94    RESTORE_STACK;
95 }
96
97 /** All the info necessary to keep track of a hypothesis during the search */
98 struct NBest {
99    celt_word32_t score;
100    int sign;
101    int pos;
102    int orig;
103    celt_word32_t xy;
104    celt_word32_t yy;
105    celt_word32_t yp;
106 };
107
108 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, celt_word16_t alpha, ec_enc *enc)
109 {
110    int L = 3;
111    VARDECL(celt_norm_t, _y);
112    VARDECL(celt_norm_t, _ny);
113    VARDECL(int, _iy);
114    VARDECL(int, _iny);
115    VARDECL(celt_norm_t *, y);
116    VARDECL(celt_norm_t *, ny);
117    VARDECL(int *, iy);
118    VARDECL(int *, iny);
119    int i, j, k, m;
120    int pulsesLeft;
121    VARDECL(celt_word32_t, xy);
122    VARDECL(celt_word32_t, yy);
123    VARDECL(celt_word32_t, yp);
124    VARDECL(struct NBest, _nbest);
125    VARDECL(struct NBest *, nbest);
126    celt_word32_t Rpp=0, Rxp=0;
127    int maxL = 1;
128 #ifdef FIXED_POINT
129    int yshift;
130 #endif
131    SAVE_STACK;
132
133 #ifdef FIXED_POINT
134    yshift = 14-EC_ILOG(K);
135 #endif
136
137    ALLOC(_y, L*N, celt_norm_t);
138    ALLOC(_ny, L*N, celt_norm_t);
139    ALLOC(_iy, L*N, int);
140    ALLOC(_iny, L*N, int);
141    ALLOC(y, L, celt_norm_t*);
142    ALLOC(ny, L, celt_norm_t*);
143    ALLOC(iy, L, int*);
144    ALLOC(iny, L, int*);
145    
146    ALLOC(xy, L, celt_word32_t);
147    ALLOC(yy, L, celt_word32_t);
148    ALLOC(yp, L, celt_word32_t);
149    ALLOC(_nbest, L, struct NBest);
150    ALLOC(nbest, L, struct NBest *);
151    
152    for (m=0;m<L;m++)
153       nbest[m] = &_nbest[m];
154    
155    for (m=0;m<L;m++)
156    {
157       ny[m] = &_ny[m*N];
158       iny[m] = &_iny[m*N];
159       y[m] = &_y[m*N];
160       iy[m] = &_iy[m*N];
161    }
162    
163    for (j=0;j<N;j++)
164    {
165       Rpp = MAC16_16(Rpp, P[j],P[j]);
166       Rxp = MAC16_16(Rxp, X[j],P[j]);
167    }
168    Rpp = ROUND16(Rpp, NORM_SHIFT);
169    Rxp = ROUND16(Rxp, NORM_SHIFT);
170    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
171
172    /* We only need to initialise the zero because the first iteration only uses that */
173    for (i=0;i<N;i++)
174       y[0][i] = 0;
175    for (i=0;i<N;i++)
176       iy[0][i] = 0;
177    xy[0] = yy[0] = yp[0] = 0;
178
179    pulsesLeft = K;
180    while (pulsesLeft > 0)
181    {
182       int pulsesAtOnce=1;
183       int Lupdate = L;
184       int L2 = L;
185       
186       /* Decide on complexity strategy */
187       pulsesAtOnce = pulsesLeft/N;
188       if (pulsesAtOnce<1)
189          pulsesAtOnce = 1;
190       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
191          Lupdate = 1;
192       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
193       L2 = Lupdate;
194       if (L2>maxL)
195       {
196          L2 = maxL;
197          maxL *= N;
198       }
199
200       for (m=0;m<Lupdate;m++)
201          nbest[m]->score = -VERY_LARGE32;
202
203       for (m=0;m<L2;m++)
204       {
205          for (j=0;j<N;j++)
206          {
207             int sign;
208             /*if (x[j]>0) sign=1; else sign=-1;*/
209             for (sign=-1;sign<=1;sign+=2)
210             {
211                /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
212                celt_word32_t Rxy, Ryy, Ryp;
213                celt_word16_t spj, aspj; /* Intermediate results */
214                celt_word32_t score;
215                celt_word32_t g;
216                celt_word16_t s = SHL16(sign*pulsesAtOnce, yshift);
217                
218                /* All pulses at one location must have the same sign. */
219                if (iy[m][j]*sign < 0)
220                   continue;
221
222                spj = MULT16_16_Q14(s, P[j]);
223                aspj = MULT16_16_Q15(alpha, spj);
224                /* Updating the sums of the new pulse(s) */
225                Rxy = xy[m] + MULT16_16(s,X[j])     - MULT16_16(MULT16_16_Q15(alpha,spj),Rxp);
226                Ryy = yy[m] + 2*MULT16_16(s,y[m][j]) + MULT16_16(s,s)   +MULT16_16(aspj,MULT16_16_Q14(aspj,Rpp)) - 2*MULT16_32_Q14(aspj,yp[m]) - 2*MULT16_16(s,MULT16_16_Q14(aspj,P[j]));
227                Ryp = yp[m] + MULT16_16(spj, SUB16(QCONST16(1.f,14),MULT16_16_Q15(alpha,Rpp)));
228                
229                /* Compute the gain such that ||p + g*y|| = 1 */
230                g = MULT16_32_Q15(
231                         celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
232                                   MULT16_16(ROUND16(Ryy,14),Rpp))
233                         - ROUND16(Ryp,14),
234                    celt_rcp(SHR32(Ryy,12)));
235                /* Knowing that gain, what's the error: (x-g*y)^2 
236                   (result is negated and we discard x^2 because it's constant) */
237                /*score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
238                score = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
239                        - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
240
241                if (score>nbest[Lupdate-1]->score)
242                {
243                   int id = Lupdate-1;
244                   struct NBest *tmp_best;
245
246                   /* Save some pointers that would be deleted and use them for the current entry*/
247                   tmp_best = nbest[Lupdate-1];
248                   while (id > 0 && score > nbest[id-1]->score)
249                      id--;
250                
251                   for (k=Lupdate-1;k>id;k--)
252                      nbest[k] = nbest[k-1];
253
254                   nbest[id] = tmp_best;
255                   nbest[id]->score = score;
256                   nbest[id]->pos = j;
257                   nbest[id]->orig = m;
258                   nbest[id]->sign = sign;
259                   nbest[id]->xy = Rxy;
260                   nbest[id]->yy = Ryy;
261                   nbest[id]->yp = Ryp;
262                }
263             }
264          }
265
266       }
267       
268       celt_assert2(nbest[0]->score > -VERY_LARGE32, "Could not find any match in VQ codebook. Something got corrupted somewhere.");
269       /* Only now that we've made the final choice, update ny/iny and others */
270       for (k=0;k<Lupdate;k++)
271       {
272          int n;
273          int is;
274          celt_norm_t s;
275          is = nbest[k]->sign*pulsesAtOnce;
276          s = SHL16(is, yshift);
277          for (n=0;n<N;n++)
278             ny[k][n] = y[nbest[k]->orig][n] - MULT16_16_Q15(alpha,MULT16_16_Q14(s,MULT16_16_Q14(P[nbest[k]->pos],P[n])));
279          ny[k][nbest[k]->pos] += s;
280
281          for (n=0;n<N;n++)
282             iny[k][n] = iy[nbest[k]->orig][n];
283          iny[k][nbest[k]->pos] += is;
284
285          xy[k] = nbest[k]->xy;
286          yy[k] = nbest[k]->yy;
287          yp[k] = nbest[k]->yp;
288       }
289       /* Swap ny/iny with y/iy */
290       for (k=0;k<Lupdate;k++)
291       {
292          celt_norm_t *tmp_ny;
293          int *tmp_iny;
294
295          tmp_ny = ny[k];
296          ny[k] = y[k];
297          y[k] = tmp_ny;
298          tmp_iny = iny[k];
299          iny[k] = iy[k];
300          iy[k] = tmp_iny;
301       }
302       pulsesLeft -= pulsesAtOnce;
303    }
304    
305 #if 0
306    if (0) {
307       celt_word32_t err=0;
308       for (i=0;i<N;i++)
309          err += (x[i]-nbest[0]->gain*y[0][i])*(x[i]-nbest[0]->gain*y[0][i]);
310       /*if (N<=10)
311         printf ("%f %d %d\n", err, K, N);*/
312    }
313    /* Sanity checks, don't bother */
314    if (0) {
315       for (i=0;i<N;i++)
316          x[i] = p[i]+nbest[0]->gain*y[0][i];
317       celt_word32_t E=1e-15;
318       int ABS = 0;
319       for (i=0;i<N;i++)
320          ABS += abs(iy[0][i]);
321       /*if (K != ABS)
322          printf ("%d %d\n", K, ABS);*/
323       for (i=0;i<N;i++)
324          E += x[i]*x[i];
325       /*printf ("%f\n", E);*/
326       E = 1/sqrt(E);
327       for (i=0;i<N;i++)
328          x[i] *= E;
329    }
330 #endif
331    
332    encode_pulses(iy[0], N, K, enc);
333    
334    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
335       due to the recursive computation used in quantisation.
336       Not quite sure whether we need that or not */
337    mix_pitch_and_residual(iy[0], X, N, K, P, alpha);
338    RESTORE_STACK;
339 }
340
341 /** Decode pulse vector and combine the result with the pitch vector to produce
342     the final normalised signal in the current band. */
343 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, celt_word16_t alpha, ec_dec *dec)
344 {
345    VARDECL(int, iy);
346    SAVE_STACK;
347    ALLOC(iy, N, int);
348    decode_pulses(iy, N, K, dec);
349    mix_pitch_and_residual(iy, X, N, K, P, alpha);
350    RESTORE_STACK;
351 }
352
353 #ifdef FIXED_POINT
354 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
355 #else
356 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
357 #endif
358
359 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
360 {
361    int i,j;
362    int best=0;
363    celt_word32_t best_score=0;
364    celt_word16_t s = 1;
365    int sign;
366    celt_word32_t E;
367    celt_word16_t pred_gain;
368    int max_pos = N0-N/B;
369    if (max_pos > 32)
370       max_pos = 32;
371
372    for (i=0;i<max_pos*B;i+=B)
373    {
374       celt_word32_t xy=0, yy=0;
375       celt_word32_t score;
376       for (j=0;j<N;j++)
377       {
378          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
379          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
380       }
381       score = DIV32(MULT16_16(ROUND16(xy,14),ROUND16(xy,14)), ROUND16(yy,14));
382       if (score > best_score)
383       {
384          best_score = score;
385          best = i;
386          if (xy>0)
387             s = 1;
388          else
389             s = -1;
390       }
391    }
392    if (s<0)
393       sign = 1;
394    else
395       sign = 0;
396    /*printf ("%d %d ", sign, best);*/
397    ec_enc_uint(enc,sign,2);
398    ec_enc_uint(enc,best/B,max_pos);
399    /*printf ("%d %f\n", best, best_score);*/
400    
401    if (K>10)
402       pred_gain = pg[10];
403    else
404       pred_gain = pg[K];
405    E = EPSILON;
406    for (j=0;j<N;j++)
407    {
408       P[j] = s*Y[best+N-j-1];
409       E = MAC16_16(E, P[j],P[j]);
410    }
411    /*pred_gain = pred_gain/sqrt(E);*/
412    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
413    for (j=0;j<N;j++)
414       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
415    if (K>0)
416    {
417       for (j=0;j<N;j++)
418          x[j] -= P[j];
419    } else {
420       for (j=0;j<N;j++)
421          x[j] = P[j];
422    }
423    /*printf ("quant ");*/
424    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
425
426 }
427
428 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
429 {
430    int j;
431    int sign;
432    celt_word16_t s;
433    int best;
434    celt_word32_t E;
435    celt_word16_t pred_gain;
436    int max_pos = N0-N/B;
437    if (max_pos > 32)
438       max_pos = 32;
439    
440    sign = ec_dec_uint(dec, 2);
441    if (sign == 0)
442       s = 1;
443    else
444       s = -1;
445    
446    best = B*ec_dec_uint(dec, max_pos);
447    /*printf ("%d %d ", sign, best);*/
448
449    if (K>10)
450       pred_gain = pg[10];
451    else
452       pred_gain = pg[K];
453    E = EPSILON;
454    for (j=0;j<N;j++)
455    {
456       P[j] = s*Y[best+N-j-1];
457       E = MAC16_16(E, P[j],P[j]);
458    }
459    /*pred_gain = pred_gain/sqrt(E);*/
460    pred_gain = MULT16_16_Q15(pred_gain,DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E)));
461    for (j=0;j<N;j++)
462       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
463    if (K==0)
464    {
465       for (j=0;j<N;j++)
466          x[j] = P[j];
467    }
468 }
469
470 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
471 {
472    int i, j;
473    celt_word32_t E;
474    celt_word16_t g;
475    
476    E = EPSILON;
477    if (N0 >= Nmax/2)
478    {
479       for (i=0;i<B;i++)
480       {
481          for (j=0;j<N/B;j++)
482          {
483             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
484             E += P[j*B+i]*P[j*B+i];
485          }
486       }
487    } else {
488       for (j=0;j<N;j++)
489       {
490          P[j] = Y[j];
491          E = MAC16_16(E, P[j],P[j]);
492       }
493    }
494    g = DIV32_16(SHL32(EXTEND32(1),14+8),celt_sqrt(E));
495    for (j=0;j<N;j++)
496       P[j] = PSHR32(MULT16_16(g, P[j]),8);
497    for (j=0;j<N;j++)
498       x[j] = P[j];
499 }
500