optimisation: merged the init loop of vq_quant().
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector, computes the gain
43     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
44 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
45 {
46    int i;
47    celt_word32_t Ryp, Ryy, Rpp;
48    celt_word32_t g;
49    VARDECL(celt_norm_t, y);
50 #ifdef FIXED_POINT
51    int yshift;
52 #endif
53    SAVE_STACK;
54 #ifdef FIXED_POINT
55    yshift = 13-celt_ilog2(K);
56 #endif
57    ALLOC(y, N, celt_norm_t);
58
59    /*for (i=0;i<N;i++)
60    printf ("%d ", iy[i]);*/
61    Rpp = 0;
62    for (i=0;i<N;i++)
63       Rpp = MAC16_16(Rpp,P[i],P[i]);
64
65    for (i=0;i<N;i++)
66       y[i] = SHL16(iy[i],yshift);
67    
68    Ryp = 0;
69    Ryy = 0;
70    /* If this doesn't generate a dual MAC (on supported archs), fire the compiler guy */
71    for (i=0;i<N;i++)
72    {
73       Ryp = MAC16_16(Ryp, y[i], P[i]);
74       Ryy = MAC16_16(Ryy, y[i], y[i]);
75    }
76    
77    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
78    g = MULT16_32_Q15(
79             celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
80                       MULT16_16(ROUND16(Ryy,14),ROUND16(Rpp,14)))
81             - ROUND16(Ryp,14),
82        celt_rcp(SHR32(Ryy,9)));
83
84    for (i=0;i<N;i++)
85       X[i] = P[i] + ROUND16(MULT16_16(y[i], g),11);
86    RESTORE_STACK;
87 }
88
89
90 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
91 {
92    VARDECL(celt_norm_t, y);
93    VARDECL(int, iy);
94    VARDECL(int, signx);
95    int i, j, is;
96    celt_word16_t s;
97    int pulsesLeft;
98    celt_word32_t sum;
99    celt_word32_t xy, yy, yp;
100    celt_word16_t Rpp;
101    int N_1; /* Inverse of N, in Q14 format (even for float) */
102 #ifdef FIXED_POINT
103    int yshift;
104 #endif
105    SAVE_STACK;
106
107 #ifdef FIXED_POINT
108    yshift = 13-celt_ilog2(K);
109 #endif
110
111    ALLOC(y, N, celt_norm_t);
112    ALLOC(iy, N, int);
113    ALLOC(signx, N, int);
114    N_1 = 512/N;
115
116    sum = 0;
117    for (j=0;j<N;j++)
118    {
119       if (X[j]>0)
120          signx[j]=1;
121       else
122          signx[j]=-1;
123       iy[j] = 0;
124       y[j] = 0;
125       sum = MAC16_16(sum, P[j],P[j]);
126    }
127    Rpp = ROUND16(sum, NORM_SHIFT);
128
129    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
130
131    xy = yy = yp = 0;
132
133    pulsesLeft = K;
134    while (pulsesLeft > 0)
135    {
136       int pulsesAtOnce=1;
137       int sign;
138       celt_word32_t Rxy, Ryy, Ryp;
139       celt_word32_t g;
140       celt_word32_t best_num;
141       celt_word16_t best_den;
142       int best_id;
143       
144       /* Decide on how many pulses to find at once */
145       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
146       if (pulsesAtOnce<1)
147          pulsesAtOnce = 1;
148
149       /* This should ensure that anything we can process will have a better score */
150       best_num = -SHR32(VERY_LARGE32,4);
151       best_den = 0;
152       best_id = 0;
153       /* Choose between fast and accurate strategy depending on where we are in the search */
154       if (pulsesLeft>1)
155       {
156          /* OPT: This loop is very CPU-intensive */
157          j=0;
158          do {
159             celt_word32_t num;
160             celt_word16_t den;
161             /* Select sign based on X[j] alone */
162             sign = signx[j];
163             s = SHL16(sign*pulsesAtOnce, yshift);
164             /* Temporary sums of the new pulse(s) */
165             Rxy = xy + MULT16_16(s,X[j]);
166             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
167             
168             /* Approximate score: we maximise Rxy/sqrt(Ryy) */
169             num = MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14)));
170             den = ROUND16(Ryy,14);
171             /* The idea is to check for num/den >= best_num/best_den, but that way
172                we can do it without any division */
173             /* OPT: Make sure to use a conditional move here */
174             if (MULT16_32_Q15(best_den, num) > MULT16_32_Q15(den, best_num))
175             {
176                best_den = den;
177                best_num = num;
178                best_id = j;
179             }
180          } while (++j<N); /* Promises we loop at least once */
181       } else {
182          for (j=0;j<N;j++)
183          {
184             celt_word32_t num;
185             /* Select sign based on X[j] alone */
186             sign = signx[j];
187             s = SHL16(sign*pulsesAtOnce, yshift);
188             /* Temporary sums of the new pulse(s) */
189             Rxy = xy + MULT16_16(s,X[j]);
190             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
191             Ryp = yp + MULT16_16(s, P[j]);
192
193             /* Compute the gain such that ||p + g*y|| = 1 */
194             g = MULT16_32_Q15(
195                      celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
196                                MULT16_16(ROUND16(Ryy,14),Rpp))
197                      - ROUND16(Ryp,14),
198                 celt_rcp(SHR32(Ryy,12)));
199             /* Knowing that gain, what's the error: (x-g*y)^2 
200                (result is negated and we discard x^2 because it's constant) */
201             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
202             num = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
203                   - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
204             if (num >= best_num)
205             {
206                best_num = num;
207                best_id = j;
208             } 
209          }
210       }
211       
212       j = best_id;
213       is = signx[j]*pulsesAtOnce;
214       s = SHL16(is, yshift);
215
216       /* Updating the sums of the new pulse(s) */
217       xy = xy + MULT16_16(s,X[j]);
218       yy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
219       yp = yp + MULT16_16(s, P[j]);
220
221       /* Only now that we've made the final choice, update y/iy */
222       y[j] += s;
223       iy[j] += is;
224       pulsesLeft -= pulsesAtOnce;
225    }
226    
227    encode_pulses(iy, N, K, enc);
228    
229    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
230    due to the recursive computation used in quantisation. */
231    mix_pitch_and_residual(iy, X, N, K, P);
232    RESTORE_STACK;
233 }
234
235
236 /** Decode pulse vector and combine the result with the pitch vector to produce
237     the final normalised signal in the current band. */
238 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
239 {
240    VARDECL(int, iy);
241    SAVE_STACK;
242    ALLOC(iy, N, int);
243    decode_pulses(iy, N, K, dec);
244    mix_pitch_and_residual(iy, X, N, K, P);
245    RESTORE_STACK;
246 }
247
248 #ifdef FIXED_POINT
249 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
250 #else
251 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
252 #endif
253
254 #define MAX_INTRA 32
255 #define LOG_MAX_INTRA 5
256       
257 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, ec_enc *enc)
258 {
259    int i,j,c;
260    int best=0;
261    celt_word32_t best_num=-SHR32(VERY_LARGE32,4);
262    celt_word16_t best_den=0;
263    celt_word16_t s = 1;
264    int sign;
265    celt_word32_t E;
266    celt_word16_t pred_gain;
267    int max_pos = N0-N;
268    VARDECL(celt_norm_t, Xr);
269    SAVE_STACK;
270
271    ALLOC(Xr, B*N, celt_norm_t);
272    
273    if (max_pos > MAX_INTRA)
274       max_pos = MAX_INTRA;
275
276    /* Reverse the samples of x without reversing the channels */
277    for (c=0;c<B;c++)
278       for (j=0;j<N;j++)
279          Xr[B*N-B*j-B+c] = x[B*j+c];
280
281    for (i=0;i<max_pos;i++)
282    {
283       celt_word32_t xy=0, yy=0;
284       celt_word32_t num;
285       celt_word16_t den;
286       const celt_word16_t * restrict xp = Xr;
287       const celt_word16_t * restrict yp = Y+B*i;
288       /* OPT: If this doesn't generate a double-MAC (on supported architectures),
289          complain to your compilor vendor */
290       j=0;
291       do {
292          xy = MAC16_16(xy, *xp, *yp);
293          yy = MAC16_16(yy, *yp, *yp);
294          xp++;
295          yp++;
296       } while (++j<B*N); /* Promises we loop at least once */
297       /* Using xy^2/yy as the score but without having to do the division */
298       num = MULT16_16(ROUND16(xy,14),ROUND16(xy,14));
299       den = ROUND16(yy,14);
300       /* If you're really desperate for speed, just use xy as the score */
301       /* OPT: Make sure to use a conditional move here */
302       if (MULT16_32_Q15(best_den, num) >  MULT16_32_Q15(den, best_num))
303       {
304          best_num = num;
305          best_den = den;
306          best = i;
307          /* Store xy as the sign. We'll normalise it to +/- 1 later. */
308          s = ROUND16(xy,14);
309       }
310    }
311    if (s<0)
312    {
313       s = -1;
314       sign = 1;
315    } else {
316       s = 1;
317       sign = 0;
318    }
319    /*printf ("%d %d ", sign, best);*/
320    ec_enc_bits(enc,sign,1);
321    if (max_pos == MAX_INTRA)
322       ec_enc_bits(enc,best,LOG_MAX_INTRA);
323    else
324       ec_enc_uint(enc,best,max_pos);
325
326    /*printf ("%d %f\n", best, best_score);*/
327    
328    if (K>10)
329       pred_gain = pg[10];
330    else
331       pred_gain = pg[K];
332    E = EPSILON;
333    for (c=0;c<B;c++)
334    {
335       for (j=0;j<N;j++)
336       {
337          P[B*j+c] = s*Y[B*best+B*(N-j-1)+c];
338          E = MAC16_16(E, P[j],P[j]);
339       }
340    }
341    /*pred_gain = pred_gain/sqrt(E);*/
342    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
343    for (j=0;j<B*N;j++)
344       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
345    if (K>0)
346    {
347       for (j=0;j<B*N;j++)
348          x[j] -= P[j];
349    } else {
350       for (j=0;j<B*N;j++)
351          x[j] = P[j];
352    }
353    /*printf ("quant ");*/
354    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
355    RESTORE_STACK;
356 }
357
358 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, ec_dec *dec)
359 {
360    int j, c;
361    int sign;
362    celt_word16_t s;
363    int best;
364    celt_word32_t E;
365    celt_word16_t pred_gain;
366    int max_pos = N0-N;
367    if (max_pos > MAX_INTRA)
368       max_pos = MAX_INTRA;
369    
370    sign = ec_dec_bits(dec, 1);
371    if (sign == 0)
372       s = 1;
373    else
374       s = -1;
375    
376    if (max_pos == MAX_INTRA)
377       best = B*ec_dec_bits(dec, LOG_MAX_INTRA);
378    else
379       best = B*ec_dec_uint(dec, max_pos);
380    /*printf ("%d %d ", sign, best);*/
381
382    if (K>10)
383       pred_gain = pg[10];
384    else
385       pred_gain = pg[K];
386    E = EPSILON;
387    for (c=0;c<B;c++)
388    {
389       for (j=0;j<N;j++)
390       {
391          P[B*j+c] = s*Y[best+B*(N-j-1)+c];
392          E = MAC16_16(E, P[j],P[j]);
393       }
394    }
395    /*pred_gain = pred_gain/sqrt(E);*/
396    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
397    for (j=0;j<B*N;j++)
398       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
399    if (K==0)
400    {
401       for (j=0;j<B*N;j++)
402          x[j] = P[j];
403    }
404 }
405
406 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, int Nmax)
407 {
408    int i, j;
409    celt_word32_t E;
410    celt_word16_t g;
411    
412    E = EPSILON;
413    if (N0 >= (Nmax>>1))
414    {
415       for (i=0;i<B;i++)
416       {
417          for (j=0;j<N;j++)
418          {
419             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
420             E += P[j*B+i]*P[j*B+i];
421          }
422       }
423    } else {
424       for (j=0;j<B*N;j++)
425       {
426          P[j] = Y[j];
427          E = MAC16_16(E, P[j],P[j]);
428       }
429    }
430    g = celt_rcp(SHL32(celt_sqrt(E),9));
431    for (j=0;j<B*N;j++)
432       P[j] = PSHR32(MULT16_16(g, P[j]),8);
433    for (j=0;j<B*N;j++)
434       x[j] = P[j];
435 }
436