Experimental code to improve both the speed and accuracy of the VQ search
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41
42 /** Takes the pitch vector and the decoded residual vector, computes the gain
43     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
44 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
45 {
46    int i;
47    celt_word32_t Ryp, Ryy, Rpp;
48    celt_word16_t ryp, ryy, rpp;
49    celt_word32_t g;
50    VARDECL(celt_norm_t, y);
51 #ifdef FIXED_POINT
52    int yshift;
53 #endif
54    SAVE_STACK;
55 #ifdef FIXED_POINT
56    yshift = 13-celt_ilog2(K);
57 #endif
58    ALLOC(y, N, celt_norm_t);
59
60    /*for (i=0;i<N;i++)
61    printf ("%d ", iy[i]);*/
62    Rpp = 0;
63    i=0;
64    do {
65       Rpp = MAC16_16(Rpp,P[i],P[i]);
66       y[i] = SHL16(iy[i],yshift);
67    } while (++i < N);
68
69    Ryp = 0;
70    Ryy = 0;
71    /* If this doesn't generate a dual MAC (on supported archs), fire the compiler guy */
72    i=0;
73    do {
74       Ryp = MAC16_16(Ryp, y[i], P[i]);
75       Ryy = MAC16_16(Ryy, y[i], y[i]);
76    } while (++i < N);
77
78    ryp = ROUND16(Ryp,14);
79    ryy = ROUND16(Ryy,14);
80    rpp = ROUND16(Rpp,14);
81    /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
82    g = MULT16_32_Q15(celt_sqrt(MAC16_16(Ryy, ryp,ryp) - MULT16_16(ryy,rpp)) - ryp,
83                      celt_rcp(SHR32(Ryy,9)));
84
85    i=0;
86    do 
87       X[i] = ADD16(P[i], ROUND16(MULT16_16(y[i], g),11));
88    while (++i < N);
89
90    RESTORE_STACK;
91 }
92
93
94 void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, ec_enc *enc)
95 {
96    VARDECL(celt_norm_t, y);
97    VARDECL(int, iy);
98    VARDECL(celt_word16_t, signx);
99    int j, is;
100    celt_word16_t s;
101    int pulsesLeft;
102    celt_word32_t sum;
103    celt_word32_t xy, yy, yp;
104    celt_word16_t Rpp;
105    int N_1; /* Inverse of N, in Q14 format (even for float) */
106 #ifdef FIXED_POINT
107    int yshift;
108 #endif
109    SAVE_STACK;
110
111 #ifdef FIXED_POINT
112    yshift = 13-celt_ilog2(K);
113 #endif
114
115    ALLOC(y, N, celt_norm_t);
116    ALLOC(iy, N, int);
117    ALLOC(signx, N, celt_word16_t);
118    N_1 = 512/N;
119
120    sum = 0;
121    j=0; do {
122       X[j] -= P[j];
123       if (X[j]>0)
124          signx[j]=1;
125       else {
126          signx[j]=-1;
127          X[j]=-X[j];
128          P[j]=-P[j];
129       }
130       iy[j] = 0;
131       y[j] = 0;
132       sum = MAC16_16(sum, P[j],P[j]);
133    } while (++j<N);
134    Rpp = ROUND16(sum, NORM_SHIFT);
135
136    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
137
138    xy = yy = yp = 0;
139
140    pulsesLeft = K;
141 #if 0
142    if (K > (N>>1))
143    {
144       float sum=0;
145       j=0; do {
146          sum += X[j];
147       }  while (++j<N);
148       sum = K/(EPSILON+sum);
149       j=0; do {
150          iy[j] = floor(sum*X[j]);
151          y[j] = iy[j];
152          yy += y[j]*y[j];
153          xy += X[j]*y[j];
154          yp += P[j]*y[j];
155          y[j] *= 2;;
156          pulsesLeft -= iy[j];
157       }  while (++j<N);
158    }
159 #endif
160    while (pulsesLeft > 1)
161    {
162       int pulsesAtOnce=1;
163       int best_id;
164       celt_word16_t magnitude;
165       celt_word32_t best_num = -VERY_LARGE16;
166       celt_word16_t best_den = 0;
167 #ifdef FIXED_POINT
168       int rshift;
169 #endif
170       /* Decide on how many pulses to find at once */
171       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
172       if (pulsesAtOnce<1)
173          pulsesAtOnce = 1;
174 #ifdef FIXED_POINT
175       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
176 #endif
177       magnitude = SHL16(pulsesAtOnce, yshift);
178
179       best_id = 0;
180       /* The squared magnitude term gets added anyway, so we might as well 
181          add it outside the loop */
182       yy = MAC16_16(yy, magnitude,magnitude);
183       /* Choose between fast and accurate strategy depending on where we are in the search */
184          /* This should ensure that anything we can process will have a better score */
185       j=0;
186       do {
187          celt_word16_t Rxy, Ryy;
188          /* Select sign based on X[j] alone */
189          s = magnitude;
190          /* Temporary sums of the new pulse(s) */
191          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
192          /* We're multiplying y[j] by two so we don't have to do it here */
193          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
194             
195             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
196          Rxy is positive because the sign is pre-computed) */
197          Rxy = MULT16_16_Q15(Rxy,Rxy);
198             /* The idea is to check for num/den >= best_num/best_den, but that way
199          we can do it without any division */
200          /* OPT: Make sure to use conditional moves here */
201          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
202          {
203             best_den = Ryy;
204             best_num = Rxy;
205             best_id = j;
206          }
207       } while (++j<N);
208       
209       j = best_id;
210       is = pulsesAtOnce;
211       s = SHL16(is, yshift);
212
213       /* Updating the sums of the new pulse(s) */
214       xy = xy + MULT16_16(s,X[j]);
215       /* We're multiplying y[j] by two so we don't have to do it here */
216       yy = yy + MULT16_16(s,y[j]);
217       yp = yp + MULT16_16(s, P[j]);
218
219       /* Only now that we've made the final choice, update y/iy */
220       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
221       y[j] += 2*s;
222       iy[j] += is;
223       pulsesLeft -= pulsesAtOnce;
224    }
225    
226    {
227       celt_word16_t g;
228       celt_word16_t best_num = -VERY_LARGE16;
229       celt_word16_t best_den = 0;
230       int best_id = 0;
231       celt_word16_t magnitude = SHL16(1, yshift);
232
233       /* The squared magnitude term gets added anyway, so we might as well 
234       add it outside the loop */
235       yy = MAC16_16(yy, magnitude,magnitude);
236       j=0;
237       do {
238          celt_word16_t Rxy, Ryy, Ryp;
239          celt_word16_t num;
240          /* Select sign based on X[j] alone */
241          s = magnitude;
242          /* Temporary sums of the new pulse(s) */
243          Rxy = ROUND16(MAC16_16(xy, s,X[j]), 14);
244          /* We're multiplying y[j] by two so we don't have to do it here */
245          Ryy = ROUND16(MAC16_16(yy, s,y[j]), 14);
246          Ryp = ROUND16(MAC16_16(yp, s,P[j]), 14);
247
248             /* Compute the gain such that ||p + g*y|| = 1 
249          ...but instead, we compute g*Ryy to avoid dividing */
250          g = celt_psqrt(MULT16_16(Ryp,Ryp) + MULT16_16(Ryy,QCONST16(1.f,14)-Rpp)) - Ryp;
251             /* Knowing that gain, what's the error: (x-g*y)^2 
252          (result is negated and we discard x^2 because it's constant) */
253          /* score = 2*g*Rxy - g*g*Ryy;*/
254 #ifdef FIXED_POINT
255          /* No need to multiply Rxy by 2 because we did it earlier */
256          num = MULT16_16_Q15(ADD16(SUB16(Rxy,g),Rxy),g);
257 #else
258          num = g*(2*Rxy-g);
259 #endif
260          if (MULT16_16(best_den, num) > MULT16_16(Ryy, best_num))
261          {
262             best_den = Ryy;
263             best_num = num;
264             best_id = j;
265          }
266       } while (++j<N);
267       iy[best_id] += 1;
268    }
269    j=0;
270    do {
271       P[j] = MULT16_16(signx[j],P[j]);
272       X[j] = MULT16_16(signx[j],X[j]);
273       if (signx[j] < 0)
274          iy[j] = -iy[j];
275    } while (++j<N);
276    encode_pulses(iy, N, K, enc);
277    
278    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
279    due to the recursive computation used in quantisation. */
280    mix_pitch_and_residual(iy, X, N, K, P);
281    RESTORE_STACK;
282 }
283
284
285 /** Decode pulse vector and combine the result with the pitch vector to produce
286     the final normalised signal in the current band. */
287 void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
288 {
289    VARDECL(int, iy);
290    SAVE_STACK;
291    ALLOC(iy, N, int);
292    decode_pulses(iy, N, K, dec);
293    mix_pitch_and_residual(iy, X, N, K, P);
294    RESTORE_STACK;
295 }
296
297 void renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
298 {
299    int i;
300    celt_word32_t E = EPSILON;
301    celt_word16_t g;
302    celt_norm_t *xptr = X;
303    for (i=0;i<N;i++)
304    {
305       E = MAC16_16(E, *xptr, *xptr);
306       xptr += stride;
307    }
308
309    g = MULT16_16_Q15(value,celt_rcp(SHL32(celt_sqrt(E),9)));
310    xptr = X;
311    for (i=0;i<N;i++)
312    {
313       *xptr = PSHR32(MULT16_16(g, *xptr),8);
314       xptr += stride;
315    }
316 }
317
318 static void fold(const CELTMode *m, int N, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
319 {
320    int j;
321    const int C = CHANNELS(m);
322    int id = N0 % (C*B);
323    /* Here, we assume that id will never be greater than N0, i.e. that 
324       no band is wider than N0. In the unlikely case it happens, we set
325       everything to zero */
326    if (id+C*N>N0)
327       for (j=0;j<C*N;j++)
328          P[j] = 0;
329    else
330       for (j=0;j<C*N;j++)
331          P[j] = Y[id++];
332 }
333
334 #define KGAIN 6
335
336 void intra_fold(const CELTMode *m, celt_norm_t * restrict x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
337 {
338    celt_word16_t pred_gain;
339    const int C = CHANNELS(m);
340
341    if (K==0)
342       pred_gain = Q15ONE;
343    else
344       pred_gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,N),(celt_word32_t)(N+KGAIN*K));
345
346    fold(m, N, Y, P, N0, B);
347
348    renormalise_vector(P, pred_gain, C*N, 1);
349 }
350