Removed the "pitch compression" in the residual quantisation. Also, removed
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector (non-compressed), 
43    applies the compression in the pitch direction, computes the gain that will
44    give ||p+g*y||=1 and mixes the residual with the pitch. */
45 static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const celt_norm_t *P)
46 {
47    int i;
48    celt_word32_t Ryp, Ryy, Rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 14-EC_ILOG(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    for (i=0;i<N;i++)
64       Rpp = MAC16_16(Rpp,P[i],P[i]);
65
66    Ryp = 0;
67    for (i=0;i<N;i++)
68       Ryp = MAC16_16(Ryp,SHL16(iy[i],yshift),P[i]);
69
70    /* Remove part of the pitch component to compute the real residual from
71    the encoded (int) one */
72    for (i=0;i<N;i++)
73       y[i] = SHL16(iy[i],yshift);
74
75    /* Recompute after the projection (I think it's right) */
76    Ryp = 0;
77    for (i=0;i<N;i++)
78       Ryp = MAC16_16(Ryp,y[i],P[i]);
79
80    Ryy = 0;
81    for (i=0;i<N;i++)
82       Ryy = MAC16_16(Ryy, y[i],y[i]);
83
84    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
85    g = MULT16_32_Q15(
86             celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
87                       MULT16_16(ROUND16(Ryy,14),ROUND16(Rpp,14)))
88             - ROUND16(Ryp,14),
89        celt_rcp(SHR32(Ryy,9)));
90
91    for (i=0;i<N;i++)
92       X[i] = P[i] + ROUND16(MULT16_16(y[i], g),11);
93    RESTORE_STACK;
94 }
95
96 /** All the info necessary to keep track of a hypothesis during the search */
97 struct NBest {
98    celt_word32_t score;
99    int sign;
100    int pos;
101    int orig;
102    celt_word32_t xy;
103    celt_word32_t yy;
104    celt_word32_t yp;
105 };
106
107 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
108 {
109    VARDECL(celt_norm_t, _y);
110    VARDECL(celt_norm_t, _ny);
111    VARDECL(int, _iy);
112    VARDECL(int, _iny);
113    celt_norm_t *y, *ny;
114    int *iy, *iny;
115    int i, j;
116    int pulsesLeft;
117    celt_word32_t xy, yy, yp;
118    struct NBest nbest;
119    celt_word32_t Rpp=0, Rxp=0;
120 #ifdef FIXED_POINT
121    int yshift;
122 #endif
123    SAVE_STACK;
124
125 #ifdef FIXED_POINT
126    yshift = 14-EC_ILOG(K);
127 #endif
128
129    ALLOC(_y, N, celt_norm_t);
130    ALLOC(_ny, N, celt_norm_t);
131    ALLOC(_iy, N, int);
132    ALLOC(_iny, N, int);
133    y = _y;
134    ny = _ny;
135    iy = _iy;
136    iny = _iny;
137    
138    for (j=0;j<N;j++)
139    {
140       Rpp = MAC16_16(Rpp, P[j],P[j]);
141       Rxp = MAC16_16(Rxp, X[j],P[j]);
142    }
143    Rpp = ROUND16(Rpp, NORM_SHIFT);
144    Rxp = ROUND16(Rxp, NORM_SHIFT);
145
146    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
147
148    for (i=0;i<N;i++)
149       y[i] = 0;
150    for (i=0;i<N;i++)
151       iy[i] = 0;
152    xy = yy = yp = 0;
153
154    pulsesLeft = K;
155    while (pulsesLeft > 0)
156    {
157       int pulsesAtOnce=1;
158       
159       /* Decide on how many pulses to find at once */
160       pulsesAtOnce = pulsesLeft/N;
161       if (pulsesAtOnce<1)
162          pulsesAtOnce = 1;
163       /*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
164
165       nbest.score = -VERY_LARGE32;
166
167       for (j=0;j<N;j++)
168       {
169          int sign;
170          /*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
171          celt_word32_t Rxy, Ryy, Ryp;
172          celt_word32_t score;
173          celt_word32_t g;
174          celt_word16_t s;
175          
176          /* Select sign based on X[j] alone */
177          if (X[j]>0) sign=1; else sign=-1;
178          s = SHL16(sign*pulsesAtOnce, yshift);
179
180          /* Updating the sums of the new pulse(s) */
181          Rxy = xy + MULT16_16(s,X[j]);
182          Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
183          Ryp = yp + MULT16_16(s, P[j]);
184          
185          if (pulsesLeft>1)
186          {
187             score = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
188          } else
189          {
190             /* Compute the gain such that ||p + g*y|| = 1 */
191             g = MULT16_32_Q15(
192                      celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
193                                MULT16_16(ROUND16(Ryy,14),Rpp))
194                      - ROUND16(Ryp,14),
195                 celt_rcp(SHR32(Ryy,12)));
196             /* Knowing that gain, what's the error: (x-g*y)^2 
197                (result is negated and we discard x^2 because it's constant) */
198             /* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
199             score = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
200                     - MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
201          }
202          
203          if (score>nbest.score)
204          {
205             nbest.score = score;
206             nbest.pos = j;
207             nbest.orig = 0;
208             nbest.sign = sign;
209             nbest.xy = Rxy;
210             nbest.yy = Ryy;
211             nbest.yp = Ryp;
212          }
213       }
214
215       celt_assert2(nbest[0]->score > -VERY_LARGE32, "Could not find any match in VQ codebook. Something got corrupted somewhere.");
216
217       /* Only now that we've made the final choice, update ny/iny and others */
218       {
219          int n;
220          int is;
221          celt_norm_t s;
222          is = nbest.sign*pulsesAtOnce;
223          s = SHL16(is, yshift);
224          for (n=0;n<N;n++)
225             ny[n] = y[n];
226          ny[nbest.pos] += s;
227
228          for (n=0;n<N;n++)
229             iny[n] = iy[n];
230          iny[nbest.pos] += is;
231
232          xy = nbest.xy;
233          yy = nbest.yy;
234          yp = nbest.yp;
235       }
236       /* Swap ny/iny with y/iy */
237       {
238          celt_norm_t *tmp_ny;
239          int *tmp_iny;
240
241          tmp_ny = ny;
242          ny = y;
243          y = tmp_ny;
244          tmp_iny = iny;
245          iny = iy;
246          iy = tmp_iny;
247       }
248       pulsesLeft -= pulsesAtOnce;
249    }
250    
251    encode_pulses(iy, N, K, enc);
252    
253    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
254    due to the recursive computation used in quantisation. */
255    mix_pitch_and_residual(iy, X, N, K, P);
256    RESTORE_STACK;
257 }
258
259
260 /** Decode pulse vector and combine the result with the pitch vector to produce
261     the final normalised signal in the current band. */
262 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
263 {
264    VARDECL(int, iy);
265    SAVE_STACK;
266    ALLOC(iy, N, int);
267    decode_pulses(iy, N, K, dec);
268    mix_pitch_and_residual(iy, X, N, K, P);
269    RESTORE_STACK;
270 }
271
272 #ifdef FIXED_POINT
273 static const celt_word16_t pg[11] = {32767, 24576, 21299, 19661, 19661, 19661, 18022, 18022, 16384, 16384, 16384};
274 #else
275 static const celt_word16_t pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
276 #endif
277
278 void intra_prediction(celt_norm_t *x, celt_mask_t *W, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_enc *enc)
279 {
280    int i,j;
281    int best=0;
282    celt_word32_t best_score=0;
283    celt_word16_t s = 1;
284    int sign;
285    celt_word32_t E;
286    celt_word16_t pred_gain;
287    int max_pos = N0-N/B;
288    if (max_pos > 32)
289       max_pos = 32;
290
291    for (i=0;i<max_pos*B;i+=B)
292    {
293       celt_word32_t xy=0, yy=0;
294       celt_word32_t score;
295       for (j=0;j<N;j++)
296       {
297          xy = MAC16_16(xy, x[j], Y[i+N-j-1]);
298          yy = MAC16_16(yy, Y[i+N-j-1], Y[i+N-j-1]);
299       }
300       score = celt_div(MULT16_16(ROUND16(xy,14),ROUND16(xy,14)), ROUND16(yy,14));
301       if (score > best_score)
302       {
303          best_score = score;
304          best = i;
305          if (xy>0)
306             s = 1;
307          else
308             s = -1;
309       }
310    }
311    if (s<0)
312       sign = 1;
313    else
314       sign = 0;
315    /*printf ("%d %d ", sign, best);*/
316    ec_enc_uint(enc,sign,2);
317    ec_enc_uint(enc,best/B,max_pos);
318    /*printf ("%d %f\n", best, best_score);*/
319    
320    if (K>10)
321       pred_gain = pg[10];
322    else
323       pred_gain = pg[K];
324    E = EPSILON;
325    for (j=0;j<N;j++)
326    {
327       P[j] = s*Y[best+N-j-1];
328       E = MAC16_16(E, P[j],P[j]);
329    }
330    /*pred_gain = pred_gain/sqrt(E);*/
331    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
332    for (j=0;j<N;j++)
333       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
334    if (K>0)
335    {
336       for (j=0;j<N;j++)
337          x[j] -= P[j];
338    } else {
339       for (j=0;j<N;j++)
340          x[j] = P[j];
341    }
342    /*printf ("quant ");*/
343    /*for (j=0;j<N;j++) printf ("%f ", P[j]);*/
344
345 }
346
347 void intra_unquant(celt_norm_t *x, int N, int K, celt_norm_t *Y, celt_norm_t *P, int B, int N0, ec_dec *dec)
348 {
349    int j;
350    int sign;
351    celt_word16_t s;
352    int best;
353    celt_word32_t E;
354    celt_word16_t pred_gain;
355    int max_pos = N0-N/B;
356    if (max_pos > 32)
357       max_pos = 32;
358    
359    sign = ec_dec_uint(dec, 2);
360    if (sign == 0)
361       s = 1;
362    else
363       s = -1;
364    
365    best = B*ec_dec_uint(dec, max_pos);
366    /*printf ("%d %d ", sign, best);*/
367
368    if (K>10)
369       pred_gain = pg[10];
370    else
371       pred_gain = pg[K];
372    E = EPSILON;
373    for (j=0;j<N;j++)
374    {
375       P[j] = s*Y[best+N-j-1];
376       E = MAC16_16(E, P[j],P[j]);
377    }
378    /*pred_gain = pred_gain/sqrt(E);*/
379    pred_gain = MULT16_16_Q15(pred_gain,celt_rcp(SHL32(celt_sqrt(E),9)));
380    for (j=0;j<N;j++)
381       P[j] = PSHR32(MULT16_16(pred_gain, P[j]),8);
382    if (K==0)
383    {
384       for (j=0;j<N;j++)
385          x[j] = P[j];
386    }
387 }
388
389 void intra_fold(celt_norm_t *x, int N, celt_norm_t *Y, celt_norm_t *P, int B, int N0, int Nmax)
390 {
391    int i, j;
392    celt_word32_t E;
393    celt_word16_t g;
394    
395    E = EPSILON;
396    if (N0 >= (Nmax>>1))
397    {
398       for (i=0;i<B;i++)
399       {
400          for (j=0;j<N/B;j++)
401          {
402             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
403             E += P[j*B+i]*P[j*B+i];
404          }
405       }
406    } else {
407       for (j=0;j<N;j++)
408       {
409          P[j] = Y[j];
410          E = MAC16_16(E, P[j],P[j]);
411       }
412    }
413    g = celt_rcp(SHL32(celt_sqrt(E),9));
414    for (j=0;j<N;j++)
415       P[j] = PSHR32(MULT16_16(g, P[j]),8);
416    for (j=0;j<N;j++)
417       x[j] = P[j];
418 }
419