Fix for folding_decision() in stereo mode and more cleaning up of the code
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41 #include "rate.h"
42
43 #ifndef M_PI
44 #define M_PI 3.141592653
45 #endif
46
47 static void exp_rotation(celt_norm_t *X, int len, int dir, int stride, int K)
48 {
49    int i, k, iter;
50    celt_word16_t c, s;
51    celt_word16_t gain, theta;
52    celt_norm_t *Xptr;
53    gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,len),(celt_word32_t)(3+len+6*K));
54    /* FIXME: Make that HALF16 instead of HALF32 */
55    theta = SUB16(Q15ONE, HALF32(MULT16_16_Q15(gain,gain)));
56    /*if (len==30)
57    {
58    for (i=0;i<len;i++)
59    X[i] = 0;
60    X[14] = 1;
61 }*/ 
62    c = celt_cos_norm(EXTEND32(theta));
63    s = dir*celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
64    if (len > 8*stride)
65       stride *= len/(8*stride);
66    iter = 1;
67    for (k=0;k<iter;k++)
68    {
69       /* We could use MULT16_16_P15 instead of MULT16_16_Q15 for more accuracy, 
70       but at this point, I really don't think it's necessary */
71       Xptr = X;
72       for (i=0;i<len-stride;i++)
73       {
74          celt_norm_t x1, x2;
75          x1 = Xptr[0];
76          x2 = Xptr[stride];
77          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
78          *Xptr++      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
79       }
80       Xptr = &X[len-2*stride-1];
81       for (i=len-2*stride-1;i>=0;i--)
82       {
83          celt_norm_t x1, x2;
84          x1 = Xptr[0];
85          x2 = Xptr[stride];
86          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
87          *Xptr--      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
88       }
89    }
90    /*if (len==30)
91    {
92    for (i=0;i<len;i++)
93    printf ("%f ", X[i]);
94    printf ("\n");
95    exit(0);
96 }*/
97 }
98
99
100 /** Takes the pitch vector and the decoded residual vector, computes the gain
101     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
102 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K)
103 {
104    int i;
105    celt_word32_t Ryy;
106    celt_word32_t g;
107    VARDECL(celt_norm_t, y);
108 #ifdef FIXED_POINT
109    int yshift;
110 #endif
111    SAVE_STACK;
112 #ifdef FIXED_POINT
113    yshift = 13-celt_ilog2(K);
114 #endif
115    ALLOC(y, N, celt_norm_t);
116
117    i=0;
118    Ryy = 0;
119    do {
120       y[i] = SHL16(iy[i],yshift);
121       Ryy = MAC16_16(Ryy, y[i], y[i]);
122    } while (++i < N);
123
124    g = MULT16_32_Q15(celt_sqrt(Ryy), celt_rcp(SHR32(Ryy,9)));
125
126    i=0;
127    do 
128       X[i] = ROUND16(MULT16_16(y[i], g),11);
129    while (++i < N);
130
131    RESTORE_STACK;
132 }
133
134
135 void alg_quant(celt_norm_t *X, int N, int K, int spread, ec_enc *enc)
136 {
137    VARDECL(celt_norm_t, y);
138    VARDECL(int, iy);
139    VARDECL(celt_word16_t, signx);
140    int j, is;
141    celt_word16_t s;
142    int pulsesLeft;
143    celt_word32_t sum;
144    celt_word32_t xy, yy;
145    int N_1; /* Inverse of N, in Q14 format (even for float) */
146 #ifdef FIXED_POINT
147    int yshift;
148 #endif
149    SAVE_STACK;
150
151    K = get_pulses(K);
152 #ifdef FIXED_POINT
153    yshift = 13-celt_ilog2(K);
154 #endif
155
156    ALLOC(y, N, celt_norm_t);
157    ALLOC(iy, N, int);
158    ALLOC(signx, N, celt_word16_t);
159    N_1 = 512/N;
160    
161    if (spread)
162       exp_rotation(X, N, 1, spread, K);
163
164    sum = 0;
165    j=0; do {
166       if (X[j]>0)
167          signx[j]=1;
168       else {
169          signx[j]=-1;
170          X[j]=-X[j];
171       }
172       iy[j] = 0;
173       y[j] = 0;
174    } while (++j<N);
175
176    celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
177
178    xy = yy = 0;
179
180    pulsesLeft = K;
181
182    /* Do a pre-search by projecting on the pyramid */
183    if (K > (N>>1))
184    {
185       celt_word16_t rcp;
186       sum=0;
187       j=0; do {
188          sum += X[j];
189       }  while (++j<N);
190
191 #ifdef FIXED_POINT
192       if (sum <= K)
193 #else
194       if (sum <= EPSILON)
195 #endif
196       {
197          X[0] = QCONST16(1.f,14);
198          j=1; do
199             X[j]=0;
200          while (++j<N);
201          sum = QCONST16(1.f,14);
202       }
203       /* Do we have sufficient accuracy here? */
204       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
205       j=0; do {
206 #ifdef FIXED_POINT
207          /* It's really important to round *towards zero* here */
208          iy[j] = MULT16_16_Q15(X[j],rcp);
209 #else
210          iy[j] = floor(rcp*X[j]);
211 #endif
212          y[j] = SHL16(iy[j],yshift);
213          yy = MAC16_16(yy, y[j],y[j]);
214          xy = MAC16_16(xy, X[j],y[j]);
215          y[j] *= 2;
216          pulsesLeft -= iy[j];
217       }  while (++j<N);
218    }
219    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
220
221    while (pulsesLeft > 0)
222    {
223       int pulsesAtOnce=1;
224       int best_id;
225       celt_word16_t magnitude;
226       celt_word32_t best_num = -VERY_LARGE16;
227       celt_word16_t best_den = 0;
228 #ifdef FIXED_POINT
229       int rshift;
230 #endif
231       /* Decide on how many pulses to find at once */
232       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
233       if (pulsesAtOnce<1)
234          pulsesAtOnce = 1;
235 #ifdef FIXED_POINT
236       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
237 #endif
238       magnitude = SHL16(pulsesAtOnce, yshift);
239
240       best_id = 0;
241       /* The squared magnitude term gets added anyway, so we might as well 
242          add it outside the loop */
243       yy = MAC16_16(yy, magnitude,magnitude);
244       /* Choose between fast and accurate strategy depending on where we are in the search */
245          /* This should ensure that anything we can process will have a better score */
246       j=0;
247       do {
248          celt_word16_t Rxy, Ryy;
249          /* Select sign based on X[j] alone */
250          s = magnitude;
251          /* Temporary sums of the new pulse(s) */
252          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
253          /* We're multiplying y[j] by two so we don't have to do it here */
254          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
255             
256             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
257          Rxy is positive because the sign is pre-computed) */
258          Rxy = MULT16_16_Q15(Rxy,Rxy);
259             /* The idea is to check for num/den >= best_num/best_den, but that way
260          we can do it without any division */
261          /* OPT: Make sure to use conditional moves here */
262          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
263          {
264             best_den = Ryy;
265             best_num = Rxy;
266             best_id = j;
267          }
268       } while (++j<N);
269       
270       j = best_id;
271       is = pulsesAtOnce;
272       s = SHL16(is, yshift);
273
274       /* Updating the sums of the new pulse(s) */
275       xy = xy + MULT16_16(s,X[j]);
276       /* We're multiplying y[j] by two so we don't have to do it here */
277       yy = yy + MULT16_16(s,y[j]);
278
279       /* Only now that we've made the final choice, update y/iy */
280       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
281       y[j] += 2*s;
282       iy[j] += is;
283       pulsesLeft -= pulsesAtOnce;
284    }
285    j=0;
286    do {
287       X[j] = MULT16_16(signx[j],X[j]);
288       if (signx[j] < 0)
289          iy[j] = -iy[j];
290    } while (++j<N);
291    encode_pulses(iy, N, K, enc);
292    
293    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
294    due to the recursive computation used in quantisation. */
295    mix_pitch_and_residual(iy, X, N, K);
296    if (spread)
297       exp_rotation(X, N, -1, spread, K);
298    RESTORE_STACK;
299 }
300
301
302 /** Decode pulse vector and combine the result with the pitch vector to produce
303     the final normalised signal in the current band. */
304 void alg_unquant(celt_norm_t *X, int N, int K, int spread, ec_dec *dec)
305 {
306    VARDECL(int, iy);
307    SAVE_STACK;
308    K = get_pulses(K);
309    ALLOC(iy, N, int);
310    decode_pulses(iy, N, K, dec);
311    mix_pitch_and_residual(iy, X, N, K);
312    if (spread)
313       exp_rotation(X, N, -1, spread, K);
314    RESTORE_STACK;
315 }
316
317 celt_word16_t renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
318 {
319    int i;
320    celt_word32_t E = EPSILON;
321    celt_word16_t rE;
322    celt_word16_t g;
323    celt_norm_t *xptr = X;
324    for (i=0;i<N;i++)
325    {
326       E = MAC16_16(E, *xptr, *xptr);
327       xptr += stride;
328    }
329
330    rE = celt_sqrt(E);
331 #ifdef FIXED_POINT
332    if (rE <= 128)
333       g = Q15ONE;
334    else
335 #endif
336       g = MULT16_16_Q15(value,celt_rcp(SHL32(rE,9)));
337    xptr = X;
338    for (i=0;i<N;i++)
339    {
340       *xptr = PSHR32(MULT16_16(g, *xptr),8);
341       xptr += stride;
342    }
343    return rE;
344 }
345
346 static void fold(const CELTMode *m, int N, const celt_norm_t * restrict Y, celt_norm_t * restrict P, int N0, int B)
347 {
348    int j;
349    int id = N0 % B;
350    /* Here, we assume that id will never be greater than N0, i.e. that 
351       no band is wider than N0. In the unlikely case it happens, we set
352       everything to zero */
353    /*{
354            int offset = (N0*C - (id+C*N))/2;
355            if (offset > C*N0/16)
356                    offset = C*N0/16;
357            offset -= offset % (C*B);
358            if (offset < 0)
359                    offset = 0;
360            //printf ("%d\n", offset);
361            id += offset;
362    }*/
363    if (id+N>N0)
364       for (j=0;j<N;j++)
365          P[j] = 0;
366    else
367       for (j=0;j<N;j++)
368          P[j] = Y[id++];
369 }
370
371 void intra_fold(const CELTMode *m, int N, const celt_norm_t * restrict Y, celt_norm_t * restrict P, int N0, int B)
372 {
373    fold(m, N, Y, P, N0, B);
374    renormalise_vector(P, Q15ONE, N, 1);
375 }
376