optimisations: caching sign of x in alg_quant(), changed celt_div()/celt_rcp()
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
43    applies the compression in the pitch direction, computes the gain that will
44    give ||p+g*y||=1 and mixes the residual with the pitch. */
45 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P)
46 {
47    int i;
48    celt_word32_t Ryp, Ryy, Rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 14-EC_ILOG(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    for (i=0;i<N;i++)
64       Rpp = MAC16_16(Rpp,P[i],P[i]);
65
66    Ryp = 0;
67    for (i=0;i<N;i++)
68       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
69
70    /* Remove part of the pitch component to compute the real residual from
71    the encoded (int) one */
72    for (i=0;i<N;i++)
73       y[i] = SHL16(iy[i],yshift);
74
75    /* Recompute after the projection (I think it's right) */
76    Ryp = 0;
77    for (i=0;i<N;i++)
78       Ryp = MAC16_16(Ryp,y[i],P[i]);
79
80    Ryy = 0;
81    for (i=0;i<N;i++)
82       Ryy = MAC16_16(Ryy, y[i],y[i]);
83
84    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
85    g = MULT16_32_Q15(
86             celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
87                       MULT16_16(ROUND16(Ryy,14),ROUND16(Rpp,14)))
88             - ROUND16(Ryp,14),
89        celt_rcp(SHR32(Ryy,9)));
90
91    for (i=0;i<N;i++)
92       X[i] = P[i] + ROUND16(MULT16_16(y[i], g),11);
93    RESTORE_STACK;
94 }
95
96 /** All the info necessary to keep track of a hypothesis during the search */
97 struct NBest {
98    celt_word32_t score;
99    int sign;
100    int pos;
101    celt_word32_t xy;
102    celt_word32_t yy;
103    celt_word32_t yp;
104 };
105
106 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
107 {
108    VARDECL(celt_norm_t, _y);
109    VARDECL(celt_norm_t, _ny);
110    VARDECL(int, _iy);
111    VARDECL(int, _iny);
112    VARDECL(int, signx);
113    celt_norm_t *y, *ny;
114    int *iy, *iny;
115    int i, j;
116    int pulsesLeft;
117    celt_word32_t xy, yy, yp;
118    struct NBest nbest;
119    celt_word32_t Rpp=0, Rxp=0;
120 #ifdef FIXED_POINT
121    int yshift;
122 #endif
123    SAVE_STACK;
124
125 #ifdef FIXED_POINT
126    yshift = 14-EC_ILOG(K);
127 #endif
128
129    ALLOC(_y, N, celt_norm_t);
130    ALLOC(_ny, N, celt_norm_t);
131    ALLOC(_iy, N, int);
132    ALLOC(_iny, N, int);
133    ALLOC(signx, N, int);
134
135    y = _y;
136    ny = _ny;
137    iy = _iy;
138    iny = _iny;
139    
140    for (j=0;j<N;j++)
141    {
142       if (X[j]>0)
143          signx[j]=1;
144       else
145          signx[j]=-1;
146    }
147    
148    for (j=0;j<N;j++)
149    {
150       Rpp = MAC16_16(Rpp, P[j],P[j]);
151       Rxp = MAC16_16(Rxp, X[j],P[j]);
152    }
153    Rpp = ROUND16(Rpp, NORM_SHIFT);
154    Rxp = ROUND16(Rxp, NORM_SHIFT);
155
156    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
157
158    for (i=0;i<N;i++)
159       y[i] = 0;
160    for (i=0;i<N;i++)
161       iy[i] = 0;
162    xy = yy = yp = 0;
163
164    pulsesLeft = K;
165    while (pulsesLeft > 0)
166    {
167       int pulsesAtOnce=1;
168       
169       /* Decide on how many pulses to find at once */
170       pulsesAtOnce = pulsesLeft/N;
171       if (pulsesAtOnce<1)
172          pulsesAtOnce = 1;
173       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
174
175       nbest.score = -VERY_LARGE32;
176
177       for (j=0;j<N;j++)
178       {
179          int sign;
180          /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
181          celt_word32_t Rxy, Ryy, Ryp;
182          celt_word32_t score;
183          celt_word32_t g;
184          celt_word16_t s;
185          
186          /* Select sign based on X[j] alone */
187          sign = signx[j];
188          s = SHL16(sign*pulsesAtOnce, yshift);
189
190          /* Updating the sums of the new pulse(s) */
191          Rxy = xy + MULT16_16(s,X[j]);
192          Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
193          Ryp = yp + MULT16_16(s, P[j]);
194          
195          if (pulsesLeft>1)
196          {
197             score = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
198          } else
199          {
200             /* Compute the gain such that ||p + g*y|| = 1 */
201             g = MULT16_32_Q15(
202                      celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
203                                MULT16_16(ROUND16(Ryy,14),Rpp))
204                      - ROUND16(Ryp,14),
205                 celt_rcp(SHR32(Ryy,12)));
206             /* Knowing that gain, what's the error: (x-g*y)^2 
207                (result is negated and we discard x^2 because it's constant) */
208             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
209             score = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
210                     - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
211          }
212          
213          if (score>nbest.score)
214          {
215             nbest.score = score;
216             nbest.pos = j;
217             nbest.sign = sign;
218             nbest.xy = Rxy;
219             nbest.yy = Ryy;
220             nbest.yp = Ryp;
221          }
222       }
223
224       celt_assert2(nbest.score > -VERY_LARGE32, "Could not find any match in VQ codebook. Something got corrupted somewhere.");
225
226       /* Only now that we've made the final choice, update ny/iny and others */
227       {
228          int n;
229          int is;
230          celt_norm_t s;
231          is = nbest.sign*pulsesAtOnce;
232          s = SHL16(is, yshift);
233          for (n=0;n<N;n++)
234             ny[n] = y[n];
235          ny[nbest.pos] += s;
236
237          for (n=0;n<N;n++)
238             iny[n] = iy[n];
239          iny[nbest.pos] += is;
240
241          xy = nbest.xy;
242          yy = nbest.yy;
243          yp = nbest.yp;
244       }
245       /* Swap ny/iny with y/iy */
246       {
247          celt_norm_t *tmp_ny;
248          int *tmp_iny;
249
250          tmp_ny = ny;
251          ny = y;
252          y = tmp_ny;
253          tmp_iny = iny;
254          iny = iy;
255          iy = tmp_iny;
256       }
257       pulsesLeft -= pulsesAtOnce;
258    }
259    
260    encode_pulses(iy, N, K, enc);
261    
262    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
263    due to the recursive computation used in quantisation. */
264    mix_pitch_and_residual(iy, X, N, K, P);
265    RESTORE_STACK;
266 }
267
268
269 /** Decode pulse vector and combine the result with the pitch vector to produce
270     the final normalised signal in the current band. */
271 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
272 {
273    VARDECL(int, iy);
274    SAVE_STACK;
275    ALLOC(iy, N, int);
276    decode_pulses(iy, N, K, dec);
277    mix_pitch_and_residual(iy, X, N, K, P);
278    RESTORE_STACK;
279 }
280
281 #ifdef FIXED_POINT
282 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
283 #else
284 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
285 #endif
286
287 #define MAX_INTRA 32
288 #define LOG_MAX_INTRA 5
289       
290 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
291 {
292    int i,j;
293    int best=0;
294    celt_word32_t best_score=0;
295    celt_word16_t s = 1;
296    int sign;
297    celt_word32_t E;
298    celt_word16_t pred_gain;
299    int max_pos = N0-N/B;
300    if (max_pos > MAX_INTRA)
301       max_pos = MAX_INTRA;
302
303    for (i=0;i<max_pos*B;i+=B)
304    {
305       celt_word32_t xy=0, yy=0;
306       celt_word32_t score;
307       /* If this doesn't generate a double-MAC on supported architectures, 
308          complain to your compilor vendor */
309       for (j=0;j<N;j++)
310       {
311          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
312          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
313       }
314       /* If you're really desperate for speed, just use xy as the score */
315       score = celt_div(MULT16_16(ROUND16(xy,14),ROUND16(xy,14)), ROUND16(yy,14));
316       if (score > best_score)
317       {
318          best_score = score;
319          best = i;
320          /* Store xy as the sign. We'll normalise it to +/- 1 later. */
321          s = ROUND16(xy,14);
322       }
323    }
324    if (s<0)
325    {
326       s = -1;
327       sign = 1;
328    } else {
329       s = 1;
330       sign = 0;
331    }
332    /*printf ("%d %d ", sign, best);*/
333    ec_enc_bits(enc,sign,1);
334    if (max_pos == MAX_INTRA)
335       ec_enc_bits(enc,best/B,LOG_MAX_INTRA);
336    else
337       ec_enc_uint(enc,best/B,max_pos);
338
339    /*printf ("%d %f\n", best, best_score);*/
340    
341    if (K>10)
342       pred_gain = pg[10];
343    else
344       pred_gain = pg[K];
345    E = EPSILON;
346    for (j=0;j<N;j++)
347    {
348       P[j] = s*Y[best+N-j-1];
349       E = MAC16_16(E, P[j],P[j]);
350    }
351    /*pred_gain = pred_gain/sqrt(E);*/
352    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
353    for (j=0;j<N;j++)
354       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
355    if (K>0)
356    {
357       for (j=0;j<N;j++)
358          x[j] -= P[j];
359    } else {
360       for (j=0;j<N;j++)
361          x[j] = P[j];
362    }
363    /*printf ("quant ");*/
364    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
365
366 }
367
368 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
369 {
370    int j;
371    int sign;
372    celt_word16_t s;
373    int best;
374    celt_word32_t E;
375    celt_word16_t pred_gain;
376    int max_pos = N0-N/B;
377    if (max_pos > MAX_INTRA)
378       max_pos = MAX_INTRA;
379    
380    sign = ec_dec_bits(dec, 1);
381    if (sign == 0)
382       s = 1;
383    else
384       s = -1;
385    
386    if (max_pos == MAX_INTRA)
387       best = B*ec_dec_bits(dec, LOG_MAX_INTRA);
388    else
389       best = B*ec_dec_uint(dec, max_pos);
390    /*printf ("%d %d ", sign, best);*/
391
392    if (K>10)
393       pred_gain = pg[10];
394    else
395       pred_gain = pg[K];
396    E = EPSILON;
397    for (j=0;j<N;j++)
398    {
399       P[j] = s*Y[best+N-j-1];
400       E = MAC16_16(E, P[j],P[j]);
401    }
402    /*pred_gain = pred_gain/sqrt(E);*/
403    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
404    for (j=0;j<N;j++)
405       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
406    if (K==0)
407    {
408       for (j=0;j<N;j++)
409          x[j] = P[j];
410    }
411 }
412
413 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
414 {
415    int i, j;
416    celt_word32_t E;
417    celt_word16_t g;
418    
419    E = EPSILON;
420    if (N0 >= (Nmax>>1))
421    {
422       for (i=0;i<B;i++)
423       {
424          for (j=0;j<N/B;j++)
425          {
426             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
427             E += P[j*B+i]*P[j*B+i];
428          }
429       }
430    } else {
431       for (j=0;j<N;j++)
432       {
433          P[j] = Y[j];
434          E = MAC16_16(E, P[j],P[j]);
435       }
436    }
437    g = celt_rcp(SHL32(celt_sqrt(E),9));
438    for (j=0;j<N;j++)
439       P[j] = PSHR32(MULT16_16(g, P[j]),8);
440    for (j=0;j<N;j++)
441       x[j] = P[j];
442 }
443