3e4de85c007bd54154f2f807d1fe9861a51d1187
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
43    applies the compression in the pitch direction, computes the gain that will
44    give ||p+g*y||=1 and mixes the residual with the pitch. */
45 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P)
46 {
47    int i;
48    celt_word32_t Ryp, Ryy, Rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 14-EC_ILOG(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    for (i=0;i<N;i++)
64       Rpp = MAC16_16(Rpp,P[i],P[i]);
65
66    Ryp = 0;
67    for (i=0;i<N;i++)
68       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
69
70    /* Remove part of the pitch component to compute the real residual from
71    the encoded (int) one */
72    for (i=0;i<N;i++)
73       y[i] = SHL16(iy[i],yshift);
74
75    /* Recompute after the projection (I think it's right) */
76    Ryp = 0;
77    for (i=0;i<N;i++)
78       Ryp = MAC16_16(Ryp,y[i],P[i]);
79
80    Ryy = 0;
81    for (i=0;i<N;i++)
82       Ryy = MAC16_16(Ryy, y[i],y[i]);
83
84    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
85    g = MULT16_32_Q15(
86             celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
87                       MULT16_16(ROUND16(Ryy,14),ROUND16(Rpp,14)))
88             - ROUND16(Ryp,14),
89        celt_rcp(SHR32(Ryy,9)));
90
91    for (i=0;i<N;i++)
92       X[i] = P[i] + ROUND16(MULT16_16(y[i], g),11);
93    RESTORE_STACK;
94 }
95
96 /** All the info necessary to keep track of a hypothesis during the search */
97 struct NBest {
98    celt_word32_t score;
99    int sign;
100    int pos;
101    int orig;
102    celt_word32_t xy;
103    celt_word32_t yy;
104    celt_word32_t yp;
105 };
106
107 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
108 {
109    VARDECL(celt_norm_t, _y);
110    VARDECL(celt_norm_t, _ny);
111    VARDECL(int, _iy);
112    VARDECL(int, _iny);
113    celt_norm_t *y, *ny;
114    int *iy, *iny;
115    int i, j;
116    int pulsesLeft;
117    celt_word32_t xy, yy, yp;
118    struct NBest nbest;
119    celt_word32_t Rpp=0, Rxp=0;
120 #ifdef FIXED_POINT
121    int yshift;
122 #endif
123    SAVE_STACK;
124
125 #ifdef FIXED_POINT
126    yshift = 14-EC_ILOG(K);
127 #endif
128
129    ALLOC(_y, N, celt_norm_t);
130    ALLOC(_ny, N, celt_norm_t);
131    ALLOC(_iy, N, int);
132    ALLOC(_iny, N, int);
133    y = _y;
134    ny = _ny;
135    iy = _iy;
136    iny = _iny;
137    
138    for (j=0;j<N;j++)
139    {
140       Rpp = MAC16_16(Rpp, P[j],P[j]);
141       Rxp = MAC16_16(Rxp, X[j],P[j]);
142    }
143    Rpp = ROUND16(Rpp, NORM_SHIFT);
144    Rxp = ROUND16(Rxp, NORM_SHIFT);
145
146    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
147
148    for (i=0;i<N;i++)
149       y[i] = 0;
150    for (i=0;i<N;i++)
151       iy[i] = 0;
152    xy = yy = yp = 0;
153
154    pulsesLeft = K;
155    while (pulsesLeft > 0)
156    {
157       int pulsesAtOnce=1;
158       
159       /* Decide on how many pulses to find at once */
160       pulsesAtOnce = pulsesLeft/N;
161       if (pulsesAtOnce<1)
162          pulsesAtOnce = 1;
163       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
164
165       nbest.score = -VERY_LARGE32;
166
167       for (j=0;j<N;j++)
168       {
169          int sign;
170          /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
171          celt_word32_t Rxy, Ryy, Ryp;
172          celt_word32_t score;
173          celt_word32_t g;
174          celt_word16_t s;
175          
176          /* Select sign based on X[j] alone */
177          if (X[j]>0) sign=1; else sign=-1;
178          s = SHL16(sign*pulsesAtOnce, yshift);
179
180          /* Updating the sums of the new pulse(s) */
181          Rxy = xy + MULT16_16(s,X[j]);
182          Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
183          Ryp = yp + MULT16_16(s, P[j]);
184          
185          if (pulsesLeft>1)
186          {
187             score = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
188          } else
189          {
190             /* Compute the gain such that ||p + g*y|| = 1 */
191             g = MULT16_32_Q15(
192                      celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
193                                MULT16_16(ROUND16(Ryy,14),Rpp))
194                      - ROUND16(Ryp,14),
195                 celt_rcp(SHR32(Ryy,12)));
196             /* Knowing that gain, what's the error: (x-g*y)^2 
197                (result is negated and we discard x^2 because it's constant) */
198             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
199             score = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
200                     - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
201          }
202          
203          if (score>nbest.score)
204          {
205             nbest.score = score;
206             nbest.pos = j;
207             nbest.orig = 0;
208             nbest.sign = sign;
209             nbest.xy = Rxy;
210             nbest.yy = Ryy;
211             nbest.yp = Ryp;
212          }
213       }
214
215       celt_assert2(nbest[0]->score > -VERY_LARGE32, "Could not find any match in VQ codebook. Something got corrupted somewhere.");
216
217       /* Only now that we've made the final choice, update ny/iny and others */
218       {
219          int n;
220          int is;
221          celt_norm_t s;
222          is = nbest.sign*pulsesAtOnce;
223          s = SHL16(is, yshift);
224          for (n=0;n<N;n++)
225             ny[n] = y[n];
226          ny[nbest.pos] += s;
227
228          for (n=0;n<N;n++)
229             iny[n] = iy[n];
230          iny[nbest.pos] += is;
231
232          xy = nbest.xy;
233          yy = nbest.yy;
234          yp = nbest.yp;
235       }
236       /* Swap ny/iny with y/iy */
237       {
238          celt_norm_t *tmp_ny;
239          int *tmp_iny;
240
241          tmp_ny = ny;
242          ny = y;
243          y = tmp_ny;
244          tmp_iny = iny;
245          iny = iy;
246          iy = tmp_iny;
247       }
248       pulsesLeft -= pulsesAtOnce;
249    }
250    
251    encode_pulses(iy, N, K, enc);
252    
253    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
254    due to the recursive computation used in quantisation. */
255    mix_pitch_and_residual(iy, X, N, K, P);
256    RESTORE_STACK;
257 }
258
259
260 /** Decode pulse vector and combine the result with the pitch vector to produce
261     the final normalised signal in the current band. */
262 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
263 {
264    VARDECL(int, iy);
265    SAVE_STACK;
266    ALLOC(iy, N, int);
267    decode_pulses(iy, N, K, dec);
268    mix_pitch_and_residual(iy, X, N, K, P);
269    RESTORE_STACK;
270 }
271
272 #ifdef FIXED_POINT
273 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
274 #else
275 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
276 #endif
277
278 #define MAX_INTRA 32
279 #define LOG_MAX_INTRA 5
280       
281 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
282 {
283    int i,j;
284    int best=0;
285    celt_word32_t best_score=0;
286    celt_word16_t s = 1;
287    int sign;
288    celt_word32_t E;
289    celt_word16_t pred_gain;
290    int max_pos = N0-N/B;
291    if (max_pos > MAX_INTRA)
292       max_pos = MAX_INTRA;
293
294    for (i=0;i<max_pos*B;i+=B)
295    {
296       celt_word32_t xy=0, yy=0;
297       celt_word32_t score;
298       /* If this doesn't generate a double-MAC on supported architectures, 
299          complain to your compilor vendor */
300       for (j=0;j<N;j++)
301       {
302          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
303          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
304       }
305       /* If you're really desperate for speed, just use xy as the score */
306       score = celt_div(MULT16_16(ROUND16(xy,14),ROUND16(xy,14)), ROUND16(yy,14));
307       if (score > best_score)
308       {
309          best_score = score;
310          best = i;
311          /* Store xy as the sign. We'll normalise it to +/- 1 later. */
312          s = ROUND16(xy,14);
313       }
314    }
315    if (s<0)
316    {
317       s = -1;
318       sign = 1;
319    } else {
320       s = 1;
321       sign = 0;
322    }
323    /*printf ("%d %d ", sign, best);*/
324    ec_enc_bits(enc,sign,1);
325    if (max_pos == MAX_INTRA)
326       ec_enc_bits(enc,best/B,LOG_MAX_INTRA);
327    else
328       ec_enc_uint(enc,best/B,max_pos);
329
330    /*printf ("%d %f\n", best, best_score);*/
331    
332    if (K>10)
333       pred_gain = pg[10];
334    else
335       pred_gain = pg[K];
336    E = EPSILON;
337    for (j=0;j<N;j++)
338    {
339       P[j] = s*Y[best+N-j-1];
340       E = MAC16_16(E, P[j],P[j]);
341    }
342    /*pred_gain = pred_gain/sqrt(E);*/
343    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
344    for (j=0;j<N;j++)
345       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
346    if (K>0)
347    {
348       for (j=0;j<N;j++)
349          x[j] -= P[j];
350    } else {
351       for (j=0;j<N;j++)
352          x[j] = P[j];
353    }
354    /*printf ("quant ");*/
355    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
356
357 }
358
359 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
360 {
361    int j;
362    int sign;
363    celt_word16_t s;
364    int best;
365    celt_word32_t E;
366    celt_word16_t pred_gain;
367    int max_pos = N0-N/B;
368    if (max_pos > MAX_INTRA)
369       max_pos = MAX_INTRA;
370    
371    sign = ec_dec_bits(dec, 1);
372    if (sign == 0)
373       s = 1;
374    else
375       s = -1;
376    
377    if (max_pos == MAX_INTRA)
378       best = B*ec_dec_bits(dec, LOG_MAX_INTRA);
379    else
380       best = B*ec_dec_uint(dec, max_pos);
381    /*printf ("%d %d ", sign, best);*/
382
383    if (K>10)
384       pred_gain = pg[10];
385    else
386       pred_gain = pg[K];
387    E = EPSILON;
388    for (j=0;j<N;j++)
389    {
390       P[j] = s*Y[best+N-j-1];
391       E = MAC16_16(E, P[j],P[j]);
392    }
393    /*pred_gain = pred_gain/sqrt(E);*/
394    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
395    for (j=0;j<N;j++)
396       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
397    if (K==0)
398    {
399       for (j=0;j<N;j++)
400          x[j] = P[j];
401    }
402 }
403
404 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
405 {
406    int i, j;
407    celt_word32_t E;
408    celt_word16_t g;
409    
410    E = EPSILON;
411    if (N0 >= (Nmax>>1))
412    {
413       for (i=0;i<B;i++)
414       {
415          for (j=0;j<N/B;j++)
416          {
417             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
418             E += P[j*B+i]*P[j*B+i];
419          }
420       }
421    } else {
422       for (j=0;j<N;j++)
423       {
424          P[j] = Y[j];
425          E = MAC16_16(E, P[j],P[j]);
426       }
427    }
428    g = celt_rcp(SHL32(celt_sqrt(E),9));
429    for (j=0;j<N;j++)
430       P[j] = PSHR32(MULT16_16(g, P[j]),8);
431    for (j=0;j<N;j++)
432       x[j] = P[j];
433 }
434