A bunch of pointers marked as "restrict" to ease the job of the compiler
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
43    applies the compression in the pitch direction, computes the gain that will
44    give ||p+g*y||=1 and mixes the residual with the pitch. */
45 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
46 {
47    int i;
48    celt_word32_t Ryp, Ryy, Rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 14-EC_ILOG(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    for (i=0;i<N;i++)
64       Rpp = MAC16_16(Rpp,P[i],P[i]);
65
66    Ryp = 0;
67    for (i=0;i<N;i++)
68       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
69
70    /* Remove part of the pitch component to compute the real residual from
71    the encoded (int) one */
72    for (i=0;i<N;i++)
73       y[i] = SHL16(iy[i],yshift);
74
75    /* Recompute after the projection (I think it's right) */
76    Ryp = 0;
77    for (i=0;i<N;i++)
78       Ryp = MAC16_16(Ryp,y[i],P[i]);
79
80    Ryy = 0;
81    for (i=0;i<N;i++)
82       Ryy = MAC16_16(Ryy, y[i],y[i]);
83
84    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
85    g = MULT16_32_Q15(
86             celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
87                       MULT16_16(ROUND16(Ryy,14),ROUND16(Rpp,14)))
88             - ROUND16(Ryp,14),
89        celt_rcp(SHR32(Ryy,9)));
90
91    for (i=0;i<N;i++)
92       X[i] = P[i] + ROUND16(MULT16_16(y[i], g),11);
93    RESTORE_STACK;
94 }
95
96
97 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
98 {
99    VARDECL(celt_norm_t, y);
100    VARDECL(int, iy);
101    VARDECL(int, signx);
102    VARDECL(celt_word32_t, scores);
103    int i, j, is;
104    celt_word16_t s;
105    int pulsesLeft;
106    celt_word32_t sum;
107    celt_word32_t xy, yy, yp;
108    celt_word16_t Rpp;
109 #ifdef FIXED_POINT
110    int yshift;
111 #endif
112    SAVE_STACK;
113
114 #ifdef FIXED_POINT
115    yshift = 14-EC_ILOG(K);
116 #endif
117
118    ALLOC(y, N, celt_norm_t);
119    ALLOC(iy, N, int);
120    ALLOC(signx, N, int);
121    ALLOC(scores, N, celt_word32_t);
122
123    for (j=0;j<N;j++)
124    {
125       if (X[j]>0)
126          signx[j]=1;
127       else
128          signx[j]=-1;
129    }
130    
131    sum = 0;
132    for (j=0;j<N;j++)
133    {
134       sum = MAC16_16(sum, P[j],P[j]);
135    }
136    Rpp = ROUND16(sum, NORM_SHIFT);
137
138    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
139
140    for (i=0;i<N;i++)
141       y[i] = 0;
142    for (i=0;i<N;i++)
143       iy[i] = 0;
144    xy = yy = yp = 0;
145
146    pulsesLeft = K;
147    while (pulsesLeft > 0)
148    {
149       int pulsesAtOnce=1;
150       int sign;
151       celt_word32_t Rxy, Ryy, Ryp;
152       celt_word32_t g;
153       
154       /* Decide on how many pulses to find at once */
155       pulsesAtOnce = pulsesLeft/N;
156       if (pulsesAtOnce<1)
157          pulsesAtOnce = 1;
158
159       /* Choose between fast and accurate strategy depending on where we are in the search */
160       if (pulsesLeft>1)
161       {
162          for (j=0;j<N;j++)
163          {
164             /* Select sign based on X[j] alone */
165             sign = signx[j];
166             s = SHL16(sign*pulsesAtOnce, yshift);
167             /* Temporary sums of the new pulse(s) */
168             Rxy = xy + MULT16_16(s,X[j]);
169             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
170             Ryp = yp + MULT16_16(s, P[j]);
171             scores[j] = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
172          }
173       } else {
174          for (j=0;j<N;j++)
175          {
176             /* Select sign based on X[j] alone */
177             sign = signx[j];
178             s = SHL16(sign*pulsesAtOnce, yshift);
179             /* Temporary sums of the new pulse(s) */
180             Rxy = xy + MULT16_16(s,X[j]);
181             Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
182             Ryp = yp + MULT16_16(s, P[j]);
183
184             /* Compute the gain such that ||p + g*y|| = 1 */
185             g = MULT16_32_Q15(
186                      celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
187                                MULT16_16(ROUND16(Ryy,14),Rpp))
188                      - ROUND16(Ryp,14),
189                 celt_rcp(SHR32(Ryy,12)));
190             /* Knowing that gain, what's the error: (x-g*y)^2 
191                (result is negated and we discard x^2 because it's constant) */
192             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
193             scores[j] = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
194                     - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
195          }
196       }
197       
198       j = find_max32(scores, N);
199       is = signx[j]*pulsesAtOnce;
200       s = SHL16(is, yshift);
201
202       /* Updating the sums of the new pulse(s) */
203       xy = xy + MULT16_16(s,X[j]);
204       yy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
205       yp = yp + MULT16_16(s, P[j]);
206
207       /* Only now that we've made the final choice, update y/iy */
208       y[j] += s;
209       iy[j] += is;
210       pulsesLeft -= pulsesAtOnce;
211    }
212    
213    encode_pulses(iy, N, K, enc);
214    
215    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
216    due to the recursive computation used in quantisation. */
217    mix_pitch_and_residual(iy, X, N, K, P);
218    RESTORE_STACK;
219 }
220
221
222 /** Decode pulse vector and combine the result with the pitch vector to produce
223     the final normalised signal in the current band. */
224 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
225 {
226    VARDECL(int, iy);
227    SAVE_STACK;
228    ALLOC(iy, N, int);
229    decode_pulses(iy, N, K, dec);
230    mix_pitch_and_residual(iy, X, N, K, P);
231    RESTORE_STACK;
232 }
233
234 #ifdef FIXED_POINT
235 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
236 #else
237 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
238 #endif
239
240 #define MAX_INTRA 32
241 #define LOG_MAX_INTRA 5
242       
243 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, ec_enc *enc)
244 {
245    int i,j;
246    int best=0;
247    celt_word32_t best_score=0;
248    celt_word16_t s = 1;
249    int sign;
250    celt_word32_t E;
251    celt_word16_t pred_gain;
252    int max_pos = N0-N/B;
253    if (max_pos > MAX_INTRA)
254       max_pos = MAX_INTRA;
255
256    for (i=0;i<max_pos*B;i+=B)
257    {
258       celt_word32_t xy=0, yy=0;
259       celt_word32_t score;
260       /* If this doesn't generate a double-MAC on supported architectures, 
261          complain to your compilor vendor */
262       for (j=0;j<N;j++)
263       {
264          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
265          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
266       }
267       /* If you're really desperate for speed, just use xy as the score */
268       score = celt_div(MULT16_16(ROUND16(xy,14),ROUND16(xy,14)), ROUND16(yy,14));
269       if (score > best_score)
270       {
271          best_score = score;
272          best = i;
273          /* Store xy as the sign. We'll normalise it to +/- 1 later. */
274          s = ROUND16(xy,14);
275       }
276    }
277    if (s<0)
278    {
279       s = -1;
280       sign = 1;
281    } else {
282       s = 1;
283       sign = 0;
284    }
285    /*printf ("%d %d ", sign, best);*/
286    ec_enc_bits(enc,sign,1);
287    if (max_pos == MAX_INTRA)
288       ec_enc_bits(enc,best/B,LOG_MAX_INTRA);
289    else
290       ec_enc_uint(enc,best/B,max_pos);
291
292    /*printf ("%d %f\n", best, best_score);*/
293    
294    if (K>10)
295       pred_gain = pg[10];
296    else
297       pred_gain = pg[K];
298    E = EPSILON;
299    for (j=0;j<N;j++)
300    {
301       P[j] = s*Y[best+N-j-1];
302       E = MAC16_16(E, P[j],P[j]);
303    }
304    /*pred_gain = pred_gain/sqrt(E);*/
305    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
306    for (j=0;j<N;j++)
307       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
308    if (K>0)
309    {
310       for (j=0;j<N;j++)
311          x[j] -= P[j];
312    } else {
313       for (j=0;j<N;j++)
314          x[j] = P[j];
315    }
316    /*printf ("quant ");*/
317    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
318
319 }
320
321 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, ec_dec *dec)
322 {
323    int j;
324    int sign;
325    celt_word16_t s;
326    int best;
327    celt_word32_t E;
328    celt_word16_t pred_gain;
329    int max_pos = N0-N/B;
330    if (max_pos > MAX_INTRA)
331       max_pos = MAX_INTRA;
332    
333    sign = ec_dec_bits(dec, 1);
334    if (sign == 0)
335       s = 1;
336    else
337       s = -1;
338    
339    if (max_pos == MAX_INTRA)
340       best = B*ec_dec_bits(dec, LOG_MAX_INTRA);
341    else
342       best = B*ec_dec_uint(dec, max_pos);
343    /*printf ("%d %d ", sign, best);*/
344
345    if (K>10)
346       pred_gain = pg[10];
347    else
348       pred_gain = pg[K];
349    E = EPSILON;
350    for (j=0;j<N;j++)
351    {
352       P[j] = s*Y[best+N-j-1];
353       E = MAC16_16(E, P[j],P[j]);
354    }
355    /*pred_gain = pred_gain/sqrt(E);*/
356    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
357    for (j=0;j<N;j++)
358       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
359    if (K==0)
360    {
361       for (j=0;j<N;j++)
362          x[j] = P[j];
363    }
364 }
365
366 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t * restrict P, int B, int N0, int Nmax)
367 {
368    int i, j;
369    celt_word32_t E;
370    celt_word16_t g;
371    
372    E = EPSILON;
373    if (N0 >= (Nmax>>1))
374    {
375       for (i=0;i<B;i++)
376       {
377          for (j=0;j<N/B;j++)
378          {
379             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
380             E += P[j*B+i]*P[j*B+i];
381          }
382       }
383    } else {
384       for (j=0;j<N;j++)
385       {
386          P[j] = Y[j];
387          E = MAC16_16(E, P[j],P[j]);
388       }
389    }
390    g = celt_rcp(SHL32(celt_sqrt(E),9));
391    for (j=0;j<N;j++)
392       P[j] = PSHR32(MULT16_16(g, P[j]),8);
393    for (j=0;j<N;j++)
394       x[j] = P[j];
395 }
396