New VQ search is now enabled by default after fixing the last remaining issues:
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector, computes the gain
43     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
44 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
45 {
46    int i;
47    celt_word32_t Ryp, Ryy, Rpp;
48    celt_word16_t ryp, ryy, rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 13-celt_ilog2(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    i=0;
64    do {
65       Rpp = MAC16_16(Rpp,P[i],P[i]);
66       y[i] = SHL16(iy[i],yshift);
67    } while (++i < N);
68
69    Ryp = 0;
70    Ryy = 0;
71    /* If this doesn't generate a dual MAC (on supported archs), fire the compiler guy */
72    i=0;
73    do {
74       Ryp = MAC16_16(Ryp, y[i], P[i]);
75       Ryy = MAC16_16(Ryy, y[i], y[i]);
76    } while (++i < N);
77
78    ryp = ROUND16(Ryp,14);
79    ryy = ROUND16(Ryy,14);
80    rpp = ROUND16(Rpp,14);
81    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
82    g = MULT16_32_Q15(celt_sqrt(MAC16_16(Ryy, ryp,ryp) - MULT16_16(ryy,rpp)) - ryp,
83                      celt_rcp(SHR32(Ryy,9)));
84
85    i=0;
86    do 
87       X[i] = ADD16(P[i], ROUND16(MULT16_16(y[i], g),11));
88    while (++i < N);
89
90    RESTORE_STACK;
91 }
92
93
94 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, ec_enc *enc)
95 {
96    VARDECL(celt_norm_t, y);
97    VARDECL(int, iy);
98    VARDECL(celt_word16_t, signx);
99    int j, is;
100    celt_word16_t s;
101    int pulsesLeft;
102    celt_word32_t sum;
103    celt_word32_t xy, yy, yp;
104    celt_word16_t Rpp;
105    int N_1; /* Inverse of N, in Q14 format (even for float) */
106 #ifdef FIXED_POINT
107    int yshift;
108 #endif
109    SAVE_STACK;
110
111 #ifdef FIXED_POINT
112    yshift = 13-celt_ilog2(K);
113 #endif
114
115    ALLOC(y, N, celt_norm_t);
116    ALLOC(iy, N, int);
117    ALLOC(signx, N, celt_word16_t);
118    N_1 = 512/N;
119
120    sum = 0;
121    j=0; do {
122       X[j] -= P[j];
123       if (X[j]>0)
124          signx[j]=1;
125       else {
126          signx[j]=-1;
127          X[j]=-X[j];
128          P[j]=-P[j];
129       }
130       iy[j] = 0;
131       y[j] = 0;
132       sum = MAC16_16(sum, P[j],P[j]);
133    } while (++j<N);
134    Rpp = ROUND16(sum, NORM_SHIFT);
135
136    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
137
138    xy = yy = yp = 0;
139
140    pulsesLeft = K;
141
142    /* Do a pre-search by projecting on the pyramid */
143    if (K > (N>>1))
144    {
145       celt_word32_t sum=0;
146       celt_word16_t rcp;
147       j=0; do {
148          sum += X[j];
149       }  while (++j<N);
150       if (sum == 0)
151       {
152          X[0] = 16384;
153          sum = 16384;
154       }
155       /* Do we have sufficient accuracy here? */
156       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
157       /*rcp = DIV32(SHL32(EXTEND32(K-1),15),EPSILON+sum);*/
158       /*printf ("%d (%d %d)\n", rcp, N, K);*/
159       j=0; do {
160 #ifdef FIXED_POINT
161          /* It's really important to round *towards zero* here */
162          iy[j] = MULT16_16_Q15(X[j],rcp);
163 #else
164          iy[j] = floor(rcp*X[j]);
165 #endif
166          y[j] = SHL16(iy[j],yshift);
167          yy = MAC16_16(yy, y[j],y[j]);
168          xy = MAC16_16(xy, X[j],y[j]);
169          yp += P[j]*y[j];
170          y[j] *= 2;
171          pulsesLeft -= iy[j];
172       }  while (++j<N);
173    }
174    /*if (pulsesLeft > N+2)
175       printf ("%d / %d (%d)\n", pulsesLeft, K, N);*/
176    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
177
178    while (pulsesLeft > 1)
179    {
180       int pulsesAtOnce=1;
181       int best_id;
182       celt_word16_t magnitude;
183       celt_word32_t best_num = -VERY_LARGE16;
184       celt_word16_t best_den = 0;
185 #ifdef FIXED_POINT
186       int rshift;
187 #endif
188       /* Decide on how many pulses to find at once */
189       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
190       if (pulsesAtOnce<1)
191          pulsesAtOnce = 1;
192 #ifdef FIXED_POINT
193       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
194 #endif
195       magnitude = SHL16(pulsesAtOnce, yshift);
196
197       best_id = 0;
198       /* The squared magnitude term gets added anyway, so we might as well 
199          add it outside the loop */
200       yy = MAC16_16(yy, magnitude,magnitude);
201       /* Choose between fast and accurate strategy depending on where we are in the search */
202          /* This should ensure that anything we can process will have a better score */
203       j=0;
204       do {
205          celt_word16_t Rxy, Ryy;
206          /* Select sign based on X[j] alone */
207          s = magnitude;
208          /* Temporary sums of the new pulse(s) */
209          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
210          /* We're multiplying y[j] by two so we don't have to do it here */
211          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
212             
213             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
214          Rxy is positive because the sign is pre-computed) */
215          Rxy = MULT16_16_Q15(Rxy,Rxy);
216             /* The idea is to check for num/den >= best_num/best_den, but that way
217          we can do it without any division */
218          /* OPT: Make sure to use conditional moves here */
219          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
220          {
221             best_den = Ryy;
222             best_num = Rxy;
223             best_id = j;
224          }
225       } while (++j<N);
226       
227       j = best_id;
228       is = pulsesAtOnce;
229       s = SHL16(is, yshift);
230
231       /* Updating the sums of the new pulse(s) */
232       xy = xy + MULT16_16(s,X[j]);
233       /* We're multiplying y[j] by two so we don't have to do it here */
234       yy = yy + MULT16_16(s,y[j]);
235       yp = yp + MULT16_16(s, P[j]);
236
237       /* Only now that we've made the final choice, update y/iy */
238       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
239       y[j] += 2*s;
240       iy[j] += is;
241       pulsesLeft -= pulsesAtOnce;
242    }
243    
244    if (pulsesLeft > 0)
245    {
246       celt_word16_t g;
247       celt_word16_t best_num = -VERY_LARGE16;
248       celt_word16_t best_den = 0;
249       int best_id = 0;
250       celt_word16_t magnitude = SHL16(1, yshift);
251
252       /* The squared magnitude term gets added anyway, so we might as well 
253       add it outside the loop */
254       yy = MAC16_16(yy, magnitude,magnitude);
255       j=0;
256       do {
257          celt_word16_t Rxy, Ryy, Ryp;
258          celt_word16_t num;
259          /* Select sign based on X[j] alone */
260          s = magnitude;
261          /* Temporary sums of the new pulse(s) */
262          Rxy = ROUND16(MAC16_16(xy, s,X[j]), 14);
263          /* We're multiplying y[j] by two so we don't have to do it here */
264          Ryy = ROUND16(MAC16_16(yy, s,y[j]), 14);
265          Ryp = ROUND16(MAC16_16(yp, s,P[j]), 14);
266
267             /* Compute the gain such that ||p + g*y|| = 1 
268          ...but instead, we compute g*Ryy to avoid dividing */
269          g = celt_psqrt(MULT16_16(Ryp,Ryp) + MULT16_16(Ryy,QCONST16(1.f,14)-Rpp)) - Ryp;
270             /* Knowing that gain, what's the error: (x-g*y)^2 
271          (result is negated and we discard x^2 because it's constant) */
272          /* score = 2*g*Rxy - g*g*Ryy;*/
273 #ifdef FIXED_POINT
274          /* No need to multiply Rxy by 2 because we did it earlier */
275          num = MULT16_16_Q15(ADD16(SUB16(Rxy,g),Rxy),g);
276 #else
277          num = g*(2*Rxy-g);
278 #endif
279          if (MULT16_16(best_den, num) > MULT16_16(Ryy, best_num))
280          {
281             best_den = Ryy;
282             best_num = num;
283             best_id = j;
284          }
285       } while (++j<N);
286       iy[best_id] += 1;
287    }
288    j=0;
289    do {
290       P[j] = MULT16_16(signx[j],P[j]);
291       X[j] = MULT16_16(signx[j],X[j]);
292       if (signx[j] < 0)
293          iy[j] = -iy[j];
294    } while (++j<N);
295    encode_pulses(iy, N, K, enc);
296    
297    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
298    due to the recursive computation used in quantisation. */
299    mix_pitch_and_residual(iy, X, N, K, P);
300    RESTORE_STACK;
301 }
302
303
304 /** Decode pulse vector and combine the result with the pitch vector to produce
305     the final normalised signal in the current band. */
306 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
307 {
308    VARDECL(int, iy);
309    SAVE_STACK;
310    ALLOC(iy, N, int);
311    decode_pulses(iy, N, K, dec);
312    mix_pitch_and_residual(iy, X, N, K, P);
313    RESTORE_STACK;
314 }
315
316 void renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
317 {
318    int i;
319    celt_word32_t E = EPSILON;
320    celt_word16_t g;
321    celt_norm_t *xptr = X;
322    for (i=0;i<N;i++)
323    {
324       E = MAC16_16(E, *xptr, *xptr);
325       xptr += stride;
326    }
327
328    g = MULT16_16_Q15(value,celt_rcp(SHL32(celt_sqrt(E),9)));
329    xptr = X;
330    for (i=0;i<N;i++)
331    {
332       *xptr = PSHR32(MULT16_16(g, *xptr),8);
333       xptr += stride;
334    }
335 }
336
337 static void fold(const CELTMode *m, int N, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
338 {
339    int j;
340    const int C = CHANNELS(m);
341    int id = N0 % (C*B);
342    /* Here, we assume that id will never be greater than N0, i.e. that 
343       no band is wider than N0. In the unlikely case it happens, we set
344       everything to zero */
345    if (id+C*N>N0)
346       for (j=0;j<C*N;j++)
347          P[j] = 0;
348    else
349       for (j=0;j<C*N;j++)
350          P[j] = Y[id++];
351 }
352
353 #define KGAIN 6
354
355 void intra_fold(const CELTMode *m, celt_norm_t * restrict x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
356 {
357    celt_word16_t pred_gain;
358    const int C = CHANNELS(m);
359
360    if (K==0)
361       pred_gain = Q15ONE;
362    else
363       pred_gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,N),(celt_word32_t)(N+KGAIN*K));
364
365    fold(m, N, Y, P, N0, B);
366
367    renormalise_vector(P, pred_gain, C*N, 1);
368 }
369