optimisation: shaving a few cycles off prev_cwrs* by not computed the values
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector, computes the gain
43     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
44 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
45 {
46    int i;
47    celt_word32_t Ryp, Ryy, Rpp;
48    celt_word32_t g;
49    VARDECL(celt_norm_t, y);
50 #ifdef FIXED_POINT
51    int yshift;
52 #endif
53    SAVE_STACK;
54 #ifdef FIXED_POINT
55    yshift = 14-EC_ILOG(K);
56 #endif
57    ALLOC(y, N, celt_norm_t);
58
59    /*for (i=0;i<N;i++)
60    printf ("%d ", iy[i]);*/
61    Rpp = 0;
62    for (i=0;i<N;i++)
63       Rpp = MAC16_16(Rpp,P[i],P[i]);
64
65    Ryp = 0;
66    for (i=0;i<N;i++)
67       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
68
69    /* Remove part of the pitch component to compute the real residual from
70    the encoded (int) one */
71    for (i=0;i<N;i++)
72       y[i] = SHL16(iy[i],yshift);
73
74    /* Recompute after the projection (I think it's right) */
75    Ryp = 0;
76    for (i=0;i<N;i++)
77       Ryp = MAC16_16(Ryp,y[i],P[i]);
78
79    Ryy = 0;
80    for (i=0;i<N;i++)
81       Ryy = MAC16_16(Ryy, y[i],y[i]);
82
83    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
84    g = MULT16_32_Q15(
85             celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
86                       MULT16_16(ROUND16(Ryy,14),ROUND16(Rpp,14)))
87             - ROUND16(Ryp,14),
88        celt_rcp(SHR32(Ryy,9)));
89
90    for (i=0;i<N;i++)
91       X[i] = P[i] + ROUND16(MULT16_16(y[i], g),11);
92    RESTORE_STACK;
93 }
94
95
96 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
97 {
98    VARDECL(celt_norm_t, y);
99    VARDECL(int, iy);
100    VARDECL(int, signx);
101    int i, j, is;
102    celt_word16_t s;
103    int pulsesLeft;
104    celt_word32_t sum;
105    celt_word32_t xy, yy, yp;
106    celt_word16_t Rpp;
107 #ifdef FIXED_POINT
108    int yshift;
109 #endif
110    SAVE_STACK;
111
112 #ifdef FIXED_POINT
113    yshift = 14-EC_ILOG(K);
114 #endif
115
116    ALLOC(y, N, celt_norm_t);
117    ALLOC(iy, N, int);
118    ALLOC(signx, N, int);
119
120    for (j=0;j<N;j++)
121    {
122       if (X[j]>0)
123          signx[j]=1;
124       else
125          signx[j]=-1;
126    }
127    
128    sum = 0;
129    for (j=0;j<N;j++)
130    {
131       sum = MAC16_16(sum, P[j],P[j]);
132    }
133    Rpp = ROUND16(sum, NORM_SHIFT);
134
135    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
136
137    for (i=0;i<N;i++)
138       y[i] = 0;
139    for (i=0;i<N;i++)
140       iy[i] = 0;
141    xy = yy = yp = 0;
142
143    pulsesLeft = K;
144    while (pulsesLeft > 0)
145    {
146       int pulsesAtOnce=1;
147       int sign;
148       celt_word32_t Rxy, Ryy, Ryp;
149       celt_word32_t g;
150       celt_word32_t best_num;
151       celt_word16_t best_den;
152       int best_id;
153       
154       /* Decide on how many pulses to find at once */
155       pulsesAtOnce = pulsesLeft/N;
156       if (pulsesAtOnce<1)
157          pulsesAtOnce = 1;
158
159       /* This should ensure that anything we can process will have a better score */
160       best_num = -SHR32(VERY_LARGE32,4);
161       best_den = 0;
162       best_id = 0;
163       /* Choose between fast and accurate strategy depending on where we are in the search */
164       if (pulsesLeft>1)
165       {
166          for (j=0;j<N;j++)
167          {
168             celt_word32_t num;
169             celt_word16_t den;
170             /* Select sign based on X[j] alone */
171             sign = signx[j];
172             s = SHL16(sign*pulsesAtOnce, yshift);
173             /* Temporary sums of the new pulse(s) */
174             Rxy = xy + MULT16_16(s,X[j]);
175             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
176             
177             /* Approximate score: we maximise Rxy/sqrt(Ryy) */
178             num = MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14)));
179             den = ROUND16(Ryy,14);
180             /* The idea is to check for num/den >= best_num/best_den, but that way
181                we can do it without any division */
182             if (MULT16_32_Q15(best_den, num) > MULT16_32_Q15(den, best_num))
183             {
184                best_den = den;
185                best_num = num;
186                best_id = j;
187             }
188          }
189       } else {
190          for (j=0;j<N;j++)
191          {
192             celt_word32_t num;
193             /* Select sign based on X[j] alone */
194             sign = signx[j];
195             s = SHL16(sign*pulsesAtOnce, yshift);
196             /* Temporary sums of the new pulse(s) */
197             Rxy = xy + MULT16_16(s,X[j]);
198             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
199             Ryp = yp + MULT16_16(s, P[j]);
200
201             /* Compute the gain such that ||p + g*y|| = 1 */
202             g = MULT16_32_Q15(
203                      celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
204                                MULT16_16(ROUND16(Ryy,14),Rpp))
205                      - ROUND16(Ryp,14),
206                 celt_rcp(SHR32(Ryy,12)));
207             /* Knowing that gain, what's the error: (x-g*y)^2 
208                (result is negated and we discard x^2 because it's constant) */
209             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
210             num = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
211                   - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
212             if (num >= best_num)
213             {
214                best_num = num;
215                best_id = j;
216             } 
217          }
218       }
219       
220       j = best_id;
221       is = signx[j]*pulsesAtOnce;
222       s = SHL16(is, yshift);
223
224       /* Updating the sums of the new pulse(s) */
225       xy = xy + MULT16_16(s,X[j]);
226       yy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
227       yp = yp + MULT16_16(s, P[j]);
228
229       /* Only now that we've made the final choice, update y/iy */
230       y[j] += s;
231       iy[j] += is;
232       pulsesLeft -= pulsesAtOnce;
233    }
234    
235    encode_pulses(iy, N, K, enc);
236    
237    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
238    due to the recursive computation used in quantisation. */
239    mix_pitch_and_residual(iy, X, N, K, P);
240    RESTORE_STACK;
241 }
242
243
244 /** Decode pulse vector and combine the result with the pitch vector to produce
245     the final normalised signal in the current band. */
246 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
247 {
248    VARDECL(int, iy);
249    SAVE_STACK;
250    ALLOC(iy, N, int);
251    decode_pulses(iy, N, K, dec);
252    mix_pitch_and_residual(iy, X, N, K, P);
253    RESTORE_STACK;
254 }
255
256 #ifdef FIXED_POINT
257 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
258 #else
259 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
260 #endif
261
262 #define MAX_INTRA 32
263 #define LOG_MAX_INTRA 5
264       
265 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, ec_enc *enc)
266 {
267    int i,j;
268    int best=0;
269    celt_word32_t best_num=-SHR32(VERY_LARGE32,4);
270    celt_word16_t best_den=0;
271    celt_word16_t s = 1;
272    int sign;
273    celt_word32_t E;
274    celt_word16_t pred_gain;
275    int max_pos = N0-N/B;
276    if (max_pos > MAX_INTRA)
277       max_pos = MAX_INTRA;
278
279    for (i=0;i<max_pos*B;i+=B)
280    {
281       celt_word32_t xy=0, yy=0;
282       celt_word32_t num;
283       celt_word16_t den;
284       /* If this doesn't generate a double-MAC on supported architectures, 
285          complain to your compilor vendor */
286       for (j=0;j<N;j++)
287       {
288          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
289          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
290       }
291       /* Using xy^2/yy as the score but without having to do the division */
292       num = MULT16_16(ROUND16(xy,14),ROUND16(xy,14));
293       den = ROUND16(yy,14);
294       /* If you're really desperate for speed, just use xy as the score */
295       if (MULT16_32_Q15(best_den, num) >  MULT16_32_Q15(den, best_num))
296       {
297          best_num = num;
298          best_den = den;
299          best = i;
300          /* Store xy as the sign. We'll normalise it to +/- 1 later. */
301          s = ROUND16(xy,14);
302       }
303    }
304    if (s<0)
305    {
306       s = -1;
307       sign = 1;
308    } else {
309       s = 1;
310       sign = 0;
311    }
312    /*printf ("%d %d ", sign, best);*/
313    ec_enc_bits(enc,sign,1);
314    if (max_pos == MAX_INTRA)
315       ec_enc_bits(enc,best/B,LOG_MAX_INTRA);
316    else
317       ec_enc_uint(enc,best/B,max_pos);
318
319    /*printf ("%d %f\n", best, best_score);*/
320    
321    if (K>10)
322       pred_gain = pg[10];
323    else
324       pred_gain = pg[K];
325    E = EPSILON;
326    for (j=0;j<N;j++)
327    {
328       P[j] = s*Y[best+N-j-1];
329       E = MAC16_16(E, P[j],P[j]);
330    }
331    /*pred_gain = pred_gain/sqrt(E);*/
332    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
333    for (j=0;j<N;j++)
334       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
335    if (K>0)
336    {
337       for (j=0;j<N;j++)
338          x[j] -= P[j];
339    } else {
340       for (j=0;j<N;j++)
341          x[j] = P[j];
342    }
343    /*printf ("quant ");*/
344    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
345
346 }
347
348 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, ec_dec *dec)
349 {
350    int j;
351    int sign;
352    celt_word16_t s;
353    int best;
354    celt_word32_t E;
355    celt_word16_t pred_gain;
356    int max_pos = N0-N/B;
357    if (max_pos > MAX_INTRA)
358       max_pos = MAX_INTRA;
359    
360    sign = ec_dec_bits(dec, 1);
361    if (sign == 0)
362       s = 1;
363    else
364       s = -1;
365    
366    if (max_pos == MAX_INTRA)
367       best = B*ec_dec_bits(dec, LOG_MAX_INTRA);
368    else
369       best = B*ec_dec_uint(dec, max_pos);
370    /*printf ("%d %d ", sign, best);*/
371
372    if (K>10)
373       pred_gain = pg[10];
374    else
375       pred_gain = pg[K];
376    E = EPSILON;
377    for (j=0;j<N;j++)
378    {
379       P[j] = s*Y[best+N-j-1];
380       E = MAC16_16(E, P[j],P[j]);
381    }
382    /*pred_gain = pred_gain/sqrt(E);*/
383    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
384    for (j=0;j<N;j++)
385       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
386    if (K==0)
387    {
388       for (j=0;j<N;j++)
389          x[j] = P[j];
390    }
391 }
392
393 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, int Nmax)
394 {
395    int i, j;
396    celt_word32_t E;
397    celt_word16_t g;
398    
399    E = EPSILON;
400    if (N0 >= (Nmax>>1))
401    {
402       for (i=0;i<B;i++)
403       {
404          for (j=0;j<N/B;j++)
405          {
406             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
407             E += P[j*B+i]*P[j*B+i];
408          }
409       }
410    } else {
411       for (j=0;j<N;j++)
412       {
413          P[j] = Y[j];
414          E = MAC16_16(E, P[j],P[j]);
415       }
416    }
417    g = celt_rcp(SHL32(celt_sqrt(E),9));
418    for (j=0;j<N;j++)
419       P[j] = PSHR32(MULT16_16(g, P[j]),8);
420    for (j=0;j<N;j++)
421       x[j] = P[j];
422 }
423