optimisations: Another bunch of simplifications to alg_quant(), mainly to
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
43    applies the compression in the pitch direction, computes the gain that will
44    give ||p+g*y||=1 and mixes the residual with the pitch. */
45 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P)
46 {
47    int i;
48    celt_word32_t Ryp, Ryy, Rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 14-EC_ILOG(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    for (i=0;i<N;i++)
64       Rpp = MAC16_16(Rpp,P[i],P[i]);
65
66    Ryp = 0;
67    for (i=0;i<N;i++)
68       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
69
70    /* Remove part of the pitch component to compute the real residual from
71    the encoded (int) one */
72    for (i=0;i<N;i++)
73       y[i] = SHL16(iy[i],yshift);
74
75    /* Recompute after the projection (I think it's right) */
76    Ryp = 0;
77    for (i=0;i<N;i++)
78       Ryp = MAC16_16(Ryp,y[i],P[i]);
79
80    Ryy = 0;
81    for (i=0;i<N;i++)
82       Ryy = MAC16_16(Ryy, y[i],y[i]);
83
84    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
85    g = MULT16_32_Q15(
86             celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
87                       MULT16_16(ROUND16(Ryy,14),ROUND16(Rpp,14)))
88             - ROUND16(Ryp,14),
89        celt_rcp(SHR32(Ryy,9)));
90
91    for (i=0;i<N;i++)
92       X[i] = P[i] + ROUND16(MULT16_16(y[i], g),11);
93    RESTORE_STACK;
94 }
95
96
97 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
98 {
99    VARDECL(celt_norm_t, y);
100    VARDECL(int, iy);
101    VARDECL(int, signx);
102    VARDECL(celt_word32_t, scores);
103    int i, j, is;
104    celt_word16_t s;
105    int pulsesLeft;
106    celt_word32_t sum;
107    celt_word32_t xy, yy, yp;
108    celt_word16_t Rpp;
109 #ifdef FIXED_POINT
110    int yshift;
111 #endif
112    SAVE_STACK;
113
114 #ifdef FIXED_POINT
115    yshift = 14-EC_ILOG(K);
116 #endif
117
118    ALLOC(y, N, celt_norm_t);
119    ALLOC(iy, N, int);
120    ALLOC(signx, N, int);
121    ALLOC(scores, N, celt_word32_t);
122
123    for (j=0;j<N;j++)
124    {
125       if (X[j]>0)
126          signx[j]=1;
127       else
128          signx[j]=-1;
129    }
130    
131    sum = 0;
132    for (j=0;j<N;j++)
133    {
134       sum = MAC16_16(sum, P[j],P[j]);
135    }
136    Rpp = ROUND16(sum, NORM_SHIFT);
137
138    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
139
140    for (i=0;i<N;i++)
141       y[i] = 0;
142    for (i=0;i<N;i++)
143       iy[i] = 0;
144    xy = yy = yp = 0;
145
146    pulsesLeft = K;
147    while (pulsesLeft > 0)
148    {
149       int pulsesAtOnce=1;
150       int sign;
151       celt_word32_t Rxy, Ryy, Ryp;
152       celt_word32_t g;
153       
154       /* Decide on how many pulses to find at once */
155       pulsesAtOnce = pulsesLeft/N;
156       if (pulsesAtOnce<1)
157          pulsesAtOnce = 1;
158       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
159
160       /* Choose between fast and accurate strategy depending on where we are in the search */
161       if (pulsesLeft>1)
162       {
163          for (j=0;j<N;j++)
164          {
165             /* Select sign based on X[j] alone */
166             sign = signx[j];
167             s = SHL16(sign*pulsesAtOnce, yshift);
168             /* Temporary sums of the new pulse(s) */
169             Rxy = xy + MULT16_16(s,X[j]);
170             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
171             Ryp = yp + MULT16_16(s, P[j]);
172             scores[j] = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
173          }
174       } else {
175          for (j=0;j<N;j++)
176          {
177             /* Select sign based on X[j] alone */
178             sign = signx[j];
179             s = SHL16(sign*pulsesAtOnce, yshift);
180             /* Temporary sums of the new pulse(s) */
181             Rxy = xy + MULT16_16(s,X[j]);
182             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
183             Ryp = yp + MULT16_16(s, P[j]);
184
185             /* Compute the gain such that ||p + g*y|| = 1 */
186             g = MULT16_32_Q15(
187                      celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
188                                MULT16_16(ROUND16(Ryy,14),Rpp))
189                      - ROUND16(Ryp,14),
190                 celt_rcp(SHR32(Ryy,12)));
191             /* Knowing that gain, what's the error: (x-g*y)^2 
192                (result is negated and we discard x^2 because it's constant) */
193             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
194             scores[j] = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
195                     - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
196          }
197       }
198       
199       j = find_max32(scores, N);
200       is = signx[j]*pulsesAtOnce;
201       s = SHL16(is, yshift);
202
203       /* Updating the sums of the new pulse(s) */
204       xy = xy + MULT16_16(s,X[j]);
205       yy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
206       yp = yp + MULT16_16(s, P[j]);
207
208       /* Only now that we've made the final choice, update y/iy */
209       y[j] += s;
210       iy[j] += is;
211       pulsesLeft -= pulsesAtOnce;
212    }
213    
214    encode_pulses(iy, N, K, enc);
215    
216    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
217    due to the recursive computation used in quantisation. */
218    mix_pitch_and_residual(iy, X, N, K, P);
219    RESTORE_STACK;
220 }
221
222
223 /** Decode pulse vector and combine the result with the pitch vector to produce
224     the final normalised signal in the current band. */
225 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
226 {
227    VARDECL(int, iy);
228    SAVE_STACK;
229    ALLOC(iy, N, int);
230    decode_pulses(iy, N, K, dec);
231    mix_pitch_and_residual(iy, X, N, K, P);
232    RESTORE_STACK;
233 }
234
235 #ifdef FIXED_POINT
236 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
237 #else
238 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
239 #endif
240
241 #define MAX_INTRA 32
242 #define LOG_MAX_INTRA 5
243       
244 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
245 {
246    int i,j;
247    int best=0;
248    celt_word32_t best_score=0;
249    celt_word16_t s = 1;
250    int sign;
251    celt_word32_t E;
252    celt_word16_t pred_gain;
253    int max_pos = N0-N/B;
254    if (max_pos > MAX_INTRA)
255       max_pos = MAX_INTRA;
256
257    for (i=0;i<max_pos*B;i+=B)
258    {
259       celt_word32_t xy=0, yy=0;
260       celt_word32_t score;
261       /* If this doesn't generate a double-MAC on supported architectures, 
262          complain to your compilor vendor */
263       for (j=0;j<N;j++)
264       {
265          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
266          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
267       }
268       /* If you're really desperate for speed, just use xy as the score */
269       score = celt_div(MULT16_16(ROUND16(xy,14),ROUND16(xy,14)), ROUND16(yy,14));
270       if (score > best_score)
271       {
272          best_score = score;
273          best = i;
274          /* Store xy as the sign. We'll normalise it to +/- 1 later. */
275          s = ROUND16(xy,14);
276       }
277    }
278    if (s<0)
279    {
280       s = -1;
281       sign = 1;
282    } else {
283       s = 1;
284       sign = 0;
285    }
286    /*printf ("%d %d ", sign, best);*/
287    ec_enc_bits(enc,sign,1);
288    if (max_pos == MAX_INTRA)
289       ec_enc_bits(enc,best/B,LOG_MAX_INTRA);
290    else
291       ec_enc_uint(enc,best/B,max_pos);
292
293    /*printf ("%d %f\n", best, best_score);*/
294    
295    if (K>10)
296       pred_gain = pg[10];
297    else
298       pred_gain = pg[K];
299    E = EPSILON;
300    for (j=0;j<N;j++)
301    {
302       P[j] = s*Y[best+N-j-1];
303       E = MAC16_16(E, P[j],P[j]);
304    }
305    /*pred_gain = pred_gain/sqrt(E);*/
306    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
307    for (j=0;j<N;j++)
308       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
309    if (K>0)
310    {
311       for (j=0;j<N;j++)
312          x[j] -= P[j];
313    } else {
314       for (j=0;j<N;j++)
315          x[j] = P[j];
316    }
317    /*printf ("quant ");*/
318    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
319
320 }
321
322 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
323 {
324    int j;
325    int sign;
326    celt_word16_t s;
327    int best;
328    celt_word32_t E;
329    celt_word16_t pred_gain;
330    int max_pos = N0-N/B;
331    if (max_pos > MAX_INTRA)
332       max_pos = MAX_INTRA;
333    
334    sign = ec_dec_bits(dec, 1);
335    if (sign == 0)
336       s = 1;
337    else
338       s = -1;
339    
340    if (max_pos == MAX_INTRA)
341       best = B*ec_dec_bits(dec, LOG_MAX_INTRA);
342    else
343       best = B*ec_dec_uint(dec, max_pos);
344    /*printf ("%d %d ", sign, best);*/
345
346    if (K>10)
347       pred_gain = pg[10];
348    else
349       pred_gain = pg[K];
350    E = EPSILON;
351    for (j=0;j<N;j++)
352    {
353       P[j] = s*Y[best+N-j-1];
354       E = MAC16_16(E, P[j],P[j]);
355    }
356    /*pred_gain = pred_gain/sqrt(E);*/
357    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
358    for (j=0;j<N;j++)
359       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
360    if (K==0)
361    {
362       for (j=0;j<N;j++)
363          x[j] = P[j];
364    }
365 }
366
367 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
368 {
369    int i, j;
370    celt_word32_t E;
371    celt_word16_t g;
372    
373    E = EPSILON;
374    if (N0 >= (Nmax>>1))
375    {
376       for (i=0;i<B;i++)
377       {
378          for (j=0;j<N/B;j++)
379          {
380             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
381             E += P[j*B+i]*P[j*B+i];
382          }
383       }
384    } else {
385       for (j=0;j<N;j++)
386       {
387          P[j] = Y[j];
388          E = MAC16_16(E, P[j],P[j]);
389       }
390    }
391    g = celt_rcp(SHL32(celt_sqrt(E),9));
392    for (j=0;j<N;j++)
393       P[j] = PSHR32(MULT16_16(g, P[j]),8);
394    for (j=0;j<N;j++)
395       x[j] = P[j];
396 }
397