Splitting transients in time domain
[opus.git] / libcelt / vq.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "mathops.h"
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42 #include "rate.h"
43
44 #ifndef M_PI
45 #define M_PI 3.141592653
46 #endif
47
48
49 static void exp_rotation1(celt_norm *X, int len, int dir, int stride, celt_word16 c, celt_word16 s)
50 {
51    int i;
52    celt_norm *Xptr;
53    if (dir>0)
54       s = -s;
55    Xptr = X;
56    for (i=0;i<len-stride;i++)
57    {
58       celt_norm x1, x2;
59       x1 = Xptr[0];
60       x2 = Xptr[stride];
61       Xptr[stride] = EXTRACT16(SHR32(MULT16_16(c,x2) + MULT16_16(s,x1), 15));
62       *Xptr++      = EXTRACT16(SHR32(MULT16_16(c,x1) - MULT16_16(s,x2), 15));
63    }
64    Xptr = &X[len-2*stride-1];
65    for (i=len-2*stride-1;i>=0;i--)
66    {
67       celt_norm x1, x2;
68       x1 = Xptr[0];
69       x2 = Xptr[stride];
70       Xptr[stride] = EXTRACT16(SHR32(MULT16_16(c,x2) + MULT16_16(s,x1), 15));
71       *Xptr--      = EXTRACT16(SHR32(MULT16_16(c,x1) - MULT16_16(s,x2), 15));
72    }
73 }
74
75 static void exp_rotation(celt_norm *X, int len, int dir, int stride, int K)
76 {
77    int i;
78    celt_word16 c, s;
79    celt_word16 gain, theta;
80    int stride2=0;
81    /*int i;
82    if (len>=30)
83    {
84       for (i=0;i<len;i++)
85          X[i] = 0;
86       X[14] = 1;
87       K=5;
88    }*/
89    if (2*K>=len)
90       return;
91    gain = celt_div((celt_word32)MULT16_16(Q15_ONE,len),(celt_word32)(len+10*K));
92    /* FIXME: Make that HALF16 instead of HALF32 */
93    theta = HALF32(MULT16_16_Q15(gain,gain));
94
95    c = celt_cos_norm(EXTEND32(theta));
96    s = celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
97
98    if (len>=8*stride)
99    {
100       stride2 = 1;
101       /* This is just a simple way of computing sqrt(len/stride) with rounding.
102          It's basically incrementing long as (stride2+0.5)^2 < len/stride.
103          I _think_ it is bit-exact */
104       while ((stride2*stride2+stride2)*stride + (stride>>2) < len)
105          stride2++;
106    }
107    len /= stride;
108    for (i=0;i<stride;i++)
109    {
110       if (dir < 0)
111       {
112          if (stride2)
113             exp_rotation1(X+i*len, len, dir, stride2, s, c);
114          exp_rotation1(X+i*len, len, dir, 1, c, s);
115       } else {
116          exp_rotation1(X+i*len, len, dir, 1, c, s);
117          if (stride2)
118             exp_rotation1(X+i*len, len, dir, stride2, s, c);
119       }
120    }
121    /*if (len>=30)
122    {
123       for (i=0;i<len;i++)
124          printf ("%f ", X[i]);
125       printf ("\n");
126       exit(0);
127    }*/
128 }
129
130 /** Takes the pitch vector and the decoded residual vector, computes the gain
131     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
132 static void normalise_residual(int * restrict iy, celt_norm * restrict X, int N, int K, celt_word32 Ryy)
133 {
134    int i;
135 #ifdef FIXED_POINT
136    int k;
137 #endif
138    celt_word32 t;
139    celt_word16 g;
140
141 #ifdef FIXED_POINT
142    k = celt_ilog2(Ryy)>>1;
143 #endif
144    t = VSHR32(Ryy, (k-7)<<1);
145    g = celt_rsqrt_norm(t);
146
147    i=0;
148    do
149       X[i] = EXTRACT16(PSHR32(MULT16_16(g, iy[i]), k+1));
150    while (++i < N);
151 }
152
153 void alg_quant(celt_norm *X, int N, int K, int spread, celt_norm *lowband, int resynth, ec_enc *enc)
154 {
155    VARDECL(celt_norm, y);
156    VARDECL(int, iy);
157    VARDECL(celt_word16, signx);
158    int j, is;
159    celt_word16 s;
160    int pulsesLeft;
161    celt_word32 sum;
162    celt_word32 xy, yy;
163    int N_1; /* Inverse of N, in Q14 format (even for float) */
164 #ifdef FIXED_POINT
165    int yshift;
166 #endif
167    SAVE_STACK;
168
169    if (K==0)
170    {
171       if (lowband != NULL && resynth)
172       {
173          for (j=0;j<N;j++)
174             X[j] = lowband[j];
175          renormalise_vector(X, Q15ONE, N, 1);
176       } else {
177          /* This is important for encoding the side in stereo mode */
178          for (j=0;j<N;j++)
179             X[j] = 0;
180       }
181       return;
182    }
183    K = get_pulses(K);
184 #ifdef FIXED_POINT
185    yshift = 13-celt_ilog2(K);
186 #endif
187
188    ALLOC(y, N, celt_norm);
189    ALLOC(iy, N, int);
190    ALLOC(signx, N, celt_word16);
191    N_1 = 512/N;
192    
193    if (spread)
194       exp_rotation(X, N, 1, spread, K);
195
196    sum = 0;
197    j=0; do {
198       if (X[j]>0)
199          signx[j]=1;
200       else {
201          signx[j]=-1;
202          X[j]=-X[j];
203       }
204       iy[j] = 0;
205       y[j] = 0;
206    } while (++j<N);
207
208    xy = yy = 0;
209
210    pulsesLeft = K;
211
212    /* Do a pre-search by projecting on the pyramid */
213    if (K > (N>>1))
214    {
215       celt_word16 rcp;
216       j=0; do {
217          sum += X[j];
218       }  while (++j<N);
219
220 #ifdef FIXED_POINT
221       if (sum <= K)
222 #else
223       if (sum <= EPSILON)
224 #endif
225       {
226          X[0] = QCONST16(1.f,14);
227          j=1; do
228             X[j]=0;
229          while (++j<N);
230          sum = QCONST16(1.f,14);
231       }
232       /* Do we have sufficient accuracy here? */
233       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
234       j=0; do {
235 #ifdef FIXED_POINT
236          /* It's really important to round *towards zero* here */
237          iy[j] = MULT16_16_Q15(X[j],rcp);
238 #else
239          iy[j] = floor(rcp*X[j]);
240 #endif
241          y[j] = SHL16(iy[j],yshift);
242          yy = MAC16_16(yy, y[j],y[j]);
243          xy = MAC16_16(xy, X[j],y[j]);
244          y[j] *= 2;
245          pulsesLeft -= iy[j];
246       }  while (++j<N);
247    }
248    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
249
250    while (pulsesLeft > 0)
251    {
252       int pulsesAtOnce=1;
253       int best_id;
254       celt_word16 magnitude;
255       celt_word32 best_num = -VERY_LARGE16;
256       celt_word16 best_den = 0;
257 #ifdef FIXED_POINT
258       int rshift;
259 #endif
260       /* Decide on how many pulses to find at once */
261       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
262       if (pulsesAtOnce<1)
263          pulsesAtOnce = 1;
264 #ifdef FIXED_POINT
265       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
266 #endif
267       magnitude = SHL16(pulsesAtOnce, yshift);
268
269       best_id = 0;
270       /* The squared magnitude term gets added anyway, so we might as well 
271          add it outside the loop */
272       yy = MAC16_16(yy, magnitude,magnitude);
273       /* Choose between fast and accurate strategy depending on where we are in the search */
274          /* This should ensure that anything we can process will have a better score */
275       j=0;
276       do {
277          celt_word16 Rxy, Ryy;
278          /* Select sign based on X[j] alone */
279          s = magnitude;
280          /* Temporary sums of the new pulse(s) */
281          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
282          /* We're multiplying y[j] by two so we don't have to do it here */
283          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
284             
285             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
286          Rxy is positive because the sign is pre-computed) */
287          Rxy = MULT16_16_Q15(Rxy,Rxy);
288             /* The idea is to check for num/den >= best_num/best_den, but that way
289          we can do it without any division */
290          /* OPT: Make sure to use conditional moves here */
291          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
292          {
293             best_den = Ryy;
294             best_num = Rxy;
295             best_id = j;
296          }
297       } while (++j<N);
298       
299       j = best_id;
300       is = pulsesAtOnce;
301       s = SHL16(is, yshift);
302
303       /* Updating the sums of the new pulse(s) */
304       xy = xy + MULT16_16(s,X[j]);
305       /* We're multiplying y[j] by two so we don't have to do it here */
306       yy = yy + MULT16_16(s,y[j]);
307
308       /* Only now that we've made the final choice, update y/iy */
309       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
310       y[j] += 2*s;
311       iy[j] += is;
312       pulsesLeft -= pulsesAtOnce;
313    }
314    j=0;
315    do {
316       X[j] = MULT16_16(signx[j],X[j]);
317       if (signx[j] < 0)
318          iy[j] = -iy[j];
319    } while (++j<N);
320    encode_pulses(iy, N, K, enc);
321    
322    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
323    due to the recursive computation used in quantisation. */
324    if (resynth)
325    {
326       normalise_residual(iy, X, N, K, EXTRACT16(SHR32(yy,2*yshift)));
327       if (spread)
328          exp_rotation(X, N, -1, spread, K);
329    }
330    RESTORE_STACK;
331 }
332
333
334 /** Decode pulse vector and combine the result with the pitch vector to produce
335     the final normalised signal in the current band. */
336 void alg_unquant(celt_norm *X, int N, int K, int spread, celt_norm *lowband, ec_dec *dec)
337 {
338    int i;
339    celt_word32 Ryy;
340    VARDECL(int, iy);
341    SAVE_STACK;
342    if (K==0)
343    {
344       if (lowband != NULL)
345       {
346          for (i=0;i<N;i++)
347             X[i] = lowband[i];
348          renormalise_vector(X, Q15ONE, N, 1);
349       } else {
350          /* This is important for encoding the side in stereo mode */
351          for (i=0;i<N;i++)
352             X[i] = 0;
353       }
354       return;
355    }
356    K = get_pulses(K);
357    ALLOC(iy, N, int);
358    decode_pulses(iy, N, K, dec);
359    Ryy = 0;
360    i=0;
361    do {
362       Ryy = MAC16_16(Ryy, iy[i], iy[i]);
363    } while (++i < N);
364    normalise_residual(iy, X, N, K, Ryy);
365    if (spread)
366       exp_rotation(X, N, -1, spread, K);
367    RESTORE_STACK;
368 }
369
370 celt_word16 renormalise_vector(celt_norm *X, celt_word16 value, int N, int stride)
371 {
372    int i;
373    celt_word32 E = EPSILON;
374    celt_word16 g;
375    celt_word32 t;
376    celt_norm *xptr = X;
377    for (i=0;i<N;i++)
378    {
379       E = MAC16_16(E, *xptr, *xptr);
380       xptr += stride;
381    }
382 #ifdef FIXED_POINT
383    int k = celt_ilog2(E)>>1;
384 #endif
385    t = VSHR32(E, (k-7)<<1);
386    g = MULT16_16_Q15(value, celt_rsqrt_norm(t));
387
388    xptr = X;
389    for (i=0;i<N;i++)
390    {
391       *xptr = EXTRACT16(PSHR32(MULT16_16(g, *xptr), k+1));
392       xptr += stride;
393    }
394    return celt_sqrt(E);
395 }
396