a few loop optimisations.
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector, computes the gain
43     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
44 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
45 {
46    int i;
47    celt_word32_t Ryp, Ryy, Rpp;
48    celt_word32_t g;
49    VARDECL(celt_norm_t, y);
50 #ifdef FIXED_POINT
51    int yshift;
52 #endif
53    SAVE_STACK;
54 #ifdef FIXED_POINT
55    yshift = 13-celt_ilog2(K);
56 #endif
57    ALLOC(y, N, celt_norm_t);
58
59    /*for (i=0;i<N;i++)
60    printf ("%d ", iy[i]);*/
61    Rpp = 0;
62    i=0;
63    do {
64       Rpp = MAC16_16(Rpp,P[i],P[i]);
65       y[i] = SHL16(iy[i],yshift);
66    } while (++i < N);
67
68    Ryp = 0;
69    Ryy = 0;
70    /* If this doesn't generate a dual MAC (on supported archs), fire the compiler guy */
71    i=0;
72    do {
73       Ryp = MAC16_16(Ryp, y[i], P[i]);
74       Ryy = MAC16_16(Ryy, y[i], y[i]);
75    } while (++i < N);
76
77    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
78    g = MULT16_32_Q15(
79             celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
80                       MULT16_16(ROUND16(Ryy,14),ROUND16(Rpp,14)))
81             - ROUND16(Ryp,14),
82        celt_rcp(SHR32(Ryy,9)));
83
84    i=0;
85    do 
86       X[i] = P[i] + ROUND16(MULT16_16(y[i], g),11);
87    while (++i < N);
88
89    RESTORE_STACK;
90 }
91
92
93 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
94 {
95    VARDECL(celt_norm_t, y);
96    VARDECL(int, iy);
97    VARDECL(int, signx);
98    int j, is;
99    celt_word16_t s;
100    int pulsesLeft;
101    celt_word32_t sum;
102    celt_word32_t xy, yy, yp;
103    celt_word16_t Rpp;
104    int N_1; /* Inverse of N, in Q14 format (even for float) */
105 #ifdef FIXED_POINT
106    int yshift;
107 #endif
108    SAVE_STACK;
109
110 #ifdef FIXED_POINT
111    yshift = 13-celt_ilog2(K);
112 #endif
113
114    ALLOC(y, N, celt_norm_t);
115    ALLOC(iy, N, int);
116    ALLOC(signx, N, int);
117    N_1 = 512/N;
118
119    sum = 0;
120    for (j=0;j<N;j++)
121    {
122       if (X[j]>0)
123          signx[j]=1;
124       else
125          signx[j]=-1;
126       iy[j] = 0;
127       y[j] = 0;
128       sum = MAC16_16(sum, P[j],P[j]);
129    }
130    Rpp = ROUND16(sum, NORM_SHIFT);
131
132    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
133
134    xy = yy = yp = 0;
135
136    pulsesLeft = K;
137    while (pulsesLeft > 0)
138    {
139       int pulsesAtOnce=1;
140       int sign;
141       celt_word32_t Rxy, Ryy, Ryp;
142       celt_word32_t g;
143       celt_word32_t best_num;
144       celt_word16_t best_den;
145       int best_id;
146       
147       /* Decide on how many pulses to find at once */
148       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
149       if (pulsesAtOnce<1)
150          pulsesAtOnce = 1;
151
152       /* This should ensure that anything we can process will have a better score */
153       best_num = -SHR32(VERY_LARGE32,4);
154       best_den = 0;
155       best_id = 0;
156       /* Choose between fast and accurate strategy depending on where we are in the search */
157       if (pulsesLeft>1)
158       {
159          /* OPT: This loop is very CPU-intensive */
160          j=0;
161          do {
162             celt_word32_t num;
163             celt_word16_t den;
164             /* Select sign based on X[j] alone */
165             sign = signx[j];
166             s = SHL16(sign*pulsesAtOnce, yshift);
167             /* Temporary sums of the new pulse(s) */
168             Rxy = xy + MULT16_16(s,X[j]);
169             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
170             
171             /* Approximate score: we maximise Rxy/sqrt(Ryy) */
172             num = MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14)));
173             den = ROUND16(Ryy,14);
174             /* The idea is to check for num/den >= best_num/best_den, but that way
175                we can do it without any division */
176             /* OPT: Make sure to use a conditional move here */
177             if (MULT16_32_Q15(best_den, num) > MULT16_32_Q15(den, best_num))
178             {
179                best_den = den;
180                best_num = num;
181                best_id = j;
182             }
183          } while (++j<N); /* Promises we loop at least once */
184       } else {
185          for (j=0;j<N;j++)
186          {
187             celt_word32_t num;
188             /* Select sign based on X[j] alone */
189             sign = signx[j];
190             s = SHL16(sign*pulsesAtOnce, yshift);
191             /* Temporary sums of the new pulse(s) */
192             Rxy = xy + MULT16_16(s,X[j]);
193             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
194             Ryp = yp + MULT16_16(s, P[j]);
195
196             /* Compute the gain such that ||p + g*y|| = 1 */
197             g = MULT16_32_Q15(
198                      celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
199                                MULT16_16(ROUND16(Ryy,14),Rpp))
200                      - ROUND16(Ryp,14),
201                 celt_rcp(SHR32(Ryy,12)));
202             /* Knowing that gain, what's the error: (x-g*y)^2 
203                (result is negated and we discard x^2 because it's constant) */
204             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
205             num = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
206                   - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
207             if (num >= best_num)
208             {
209                best_num = num;
210                best_id = j;
211             } 
212          }
213       }
214       
215       j = best_id;
216       is = signx[j]*pulsesAtOnce;
217       s = SHL16(is, yshift);
218
219       /* Updating the sums of the new pulse(s) */
220       xy = xy + MULT16_16(s,X[j]);
221       yy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
222       yp = yp + MULT16_16(s, P[j]);
223
224       /* Only now that we've made the final choice, update y/iy */
225       y[j] += s;
226       iy[j] += is;
227       pulsesLeft -= pulsesAtOnce;
228    }
229    
230    encode_pulses(iy, N, K, enc);
231    
232    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
233    due to the recursive computation used in quantisation. */
234    mix_pitch_and_residual(iy, X, N, K, P);
235    RESTORE_STACK;
236 }
237
238
239 /** Decode pulse vector and combine the result with the pitch vector to produce
240     the final normalised signal in the current band. */
241 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
242 {
243    VARDECL(int, iy);
244    SAVE_STACK;
245    ALLOC(iy, N, int);
246    decode_pulses(iy, N, K, dec);
247    mix_pitch_and_residual(iy, X, N, K, P);
248    RESTORE_STACK;
249 }
250
251 #ifdef FIXED_POINT
252 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
253 #else
254 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
255 #endif
256
257 #define MAX_INTRA 32
258 #define LOG_MAX_INTRA 5
259       
260 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, ec_enc *enc)
261 {
262    int i,j,c;
263    int best=0;
264    celt_word32_t best_num=-SHR32(VERY_LARGE32,4);
265    celt_word16_t best_den=0;
266    celt_word16_t s = 1;
267    int sign;
268    celt_word32_t E;
269    celt_word16_t pred_gain;
270    int max_pos = N0-N;
271    VARDECL(celt_norm_t, Xr);
272    SAVE_STACK;
273
274    ALLOC(Xr, B*N, celt_norm_t);
275    
276    if (max_pos > MAX_INTRA)
277       max_pos = MAX_INTRA;
278
279    /* Reverse the samples of x without reversing the channels */
280    for (c=0;c<B;c++)
281       for (j=0;j<N;j++)
282          Xr[B*N-B*j-B+c] = x[B*j+c];
283
284    for (i=0;i<max_pos;i++)
285    {
286       celt_word32_t xy=0, yy=0;
287       celt_word32_t num;
288       celt_word16_t den;
289       const celt_word16_t * restrict xp = Xr;
290       const celt_word16_t * restrict yp = Y+B*i;
291       /* OPT: If this doesn't generate a double-MAC (on supported architectures),
292          complain to your compilor vendor */
293       j=0;
294       do {
295          xy = MAC16_16(xy, *xp, *yp);
296          yy = MAC16_16(yy, *yp, *yp);
297          xp++;
298          yp++;
299       } while (++j<B*N); /* Promises we loop at least once */
300       /* Using xy^2/yy as the score but without having to do the division */
301       num = MULT16_16(ROUND16(xy,14),ROUND16(xy,14));
302       den = ROUND16(yy,14);
303       /* If you're really desperate for speed, just use xy as the score */
304       /* OPT: Make sure to use a conditional move here */
305       if (MULT16_32_Q15(best_den, num) >  MULT16_32_Q15(den, best_num))
306       {
307          best_num = num;
308          best_den = den;
309          best = i;
310          /* Store xy as the sign. We'll normalise it to +/- 1 later. */
311          s = ROUND16(xy,14);
312       }
313    }
314    if (s<0)
315    {
316       s = -1;
317       sign = 1;
318    } else {
319       s = 1;
320       sign = 0;
321    }
322    /*printf ("%d %d ", sign, best);*/
323    ec_enc_bits(enc,sign,1);
324    if (max_pos == MAX_INTRA)
325       ec_enc_bits(enc,best,LOG_MAX_INTRA);
326    else
327       ec_enc_uint(enc,best,max_pos);
328
329    /*printf ("%d %f\n", best, best_score);*/
330    
331    if (K>10)
332       pred_gain = pg[10];
333    else
334       pred_gain = pg[K];
335    E = EPSILON;
336    for (c=0;c<B;c++)
337    {
338       for (j=0;j<N;j++)
339       {
340          P[B*j+c] = s*Y[B*best+B*(N-j-1)+c];
341          E = MAC16_16(E, P[j],P[j]);
342       }
343    }
344    /*pred_gain = pred_gain/sqrt(E);*/
345    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
346    for (j=0;j<B*N;j++)
347       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
348    if (K>0)
349    {
350       for (j=0;j<B*N;j++)
351          x[j] -= P[j];
352    } else {
353       for (j=0;j<B*N;j++)
354          x[j] = P[j];
355    }
356    /*printf ("quant ");*/
357    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
358    RESTORE_STACK;
359 }
360
361 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, ec_dec *dec)
362 {
363    int j, c;
364    int sign;
365    celt_word16_t s;
366    int best;
367    celt_word32_t E;
368    celt_word16_t pred_gain;
369    int max_pos = N0-N;
370    if (max_pos > MAX_INTRA)
371       max_pos = MAX_INTRA;
372    
373    sign = ec_dec_bits(dec, 1);
374    if (sign == 0)
375       s = 1;
376    else
377       s = -1;
378    
379    if (max_pos == MAX_INTRA)
380       best = B*ec_dec_bits(dec, LOG_MAX_INTRA);
381    else
382       best = B*ec_dec_uint(dec, max_pos);
383    /*printf ("%d %d ", sign, best);*/
384
385    if (K>10)
386       pred_gain = pg[10];
387    else
388       pred_gain = pg[K];
389    E = EPSILON;
390    for (c=0;c<B;c++)
391    {
392       for (j=0;j<N;j++)
393       {
394          P[B*j+c] = s*Y[best+B*(N-j-1)+c];
395          E = MAC16_16(E, P[j],P[j]);
396       }
397    }
398    /*pred_gain = pred_gain/sqrt(E);*/
399    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
400    for (j=0;j<B*N;j++)
401       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
402    if (K==0)
403    {
404       for (j=0;j<B*N;j++)
405          x[j] = P[j];
406    }
407 }
408
409 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, int Nmax)
410 {
411    int i, j;
412    celt_word32_t E;
413    celt_word16_t g;
414    
415    E = EPSILON;
416    if (N0 >= (Nmax>>1))
417    {
418       for (i=0;i<B;i++)
419       {
420          for (j=0;j<N;j++)
421          {
422             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
423             E += P[j*B+i]*P[j*B+i];
424          }
425       }
426    } else {
427       for (j=0;j<B*N;j++)
428       {
429          P[j] = Y[j];
430          E = MAC16_16(E, P[j],P[j]);
431       }
432    }
433    g = celt_rcp(SHL32(celt_sqrt(E),9));
434    for (j=0;j<B*N;j++)
435       P[j] = PSHR32(MULT16_16(g, P[j]),8);
436    for (j=0;j<B*N;j++)
437       x[j] = P[j];
438 }
439