Fix for some test program compat and an assertion that didn't make sense anymore
[opus.git] / libcelt / vq.c
1 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #ifdef HAVE_CONFIG_H
33 #include "config.h"
34 #endif
35
36 #include "mathops.h"
37 #include "cwrs.h"
38 #include "vq.h"
39 #include "arch.h"
40 #include "os_support.h"
41 #include "rate.h"
42
43 #ifndef M_PI
44 #define M_PI 3.141592653
45 #endif
46
47 static void exp_rotation(celt_norm_t *X, int len, int dir, int stride, int K)
48 {
49    int i, k, iter;
50    celt_word16_t c, s;
51    celt_word16_t gain, theta;
52    celt_norm_t *Xptr;
53    gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,len),(celt_word32_t)(3+len+6*K));
54    /* FIXME: Make that HALF16 instead of HALF32 */
55    theta = SUB16(Q15ONE, HALF32(MULT16_16_Q15(gain,gain)));
56    /*if (len==30)
57    {
58    for (i=0;i<len;i++)
59    X[i] = 0;
60    X[14] = 1;
61 }*/ 
62    c = celt_cos_norm(EXTEND32(theta));
63    s = dir*celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
64    if (len > 8*stride)
65       stride *= len/(8*stride);
66    iter = 1;
67    for (k=0;k<iter;k++)
68    {
69       /* We could use MULT16_16_P15 instead of MULT16_16_Q15 for more accuracy, 
70       but at this point, I really don't think it's necessary */
71       Xptr = X;
72       for (i=0;i<len-stride;i++)
73       {
74          celt_norm_t x1, x2;
75          x1 = Xptr[0];
76          x2 = Xptr[stride];
77          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
78          *Xptr++      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
79       }
80       Xptr = &X[len-2*stride-1];
81       for (i=len-2*stride-1;i>=0;i--)
82       {
83          celt_norm_t x1, x2;
84          x1 = Xptr[0];
85          x2 = Xptr[stride];
86          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
87          *Xptr--      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
88       }
89    }
90    /*if (len==30)
91    {
92    for (i=0;i<len;i++)
93    printf ("%f ", X[i]);
94    printf ("\n");
95    exit(0);
96 }*/
97 }
98
99
100 /** Takes the pitch vector and the decoded residual vector, computes the gain
101     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
102 static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K)
103 {
104    int i;
105    celt_word32_t Ryy;
106    celt_word32_t g;
107    VARDECL(celt_norm_t, y);
108 #ifdef FIXED_POINT
109    int yshift;
110 #endif
111    SAVE_STACK;
112 #ifdef FIXED_POINT
113    yshift = 13-celt_ilog2(K);
114 #endif
115    ALLOC(y, N, celt_norm_t);
116
117    i=0;
118    Ryy = 0;
119    do {
120       y[i] = SHL16(iy[i],yshift);
121       Ryy = MAC16_16(Ryy, y[i], y[i]);
122    } while (++i < N);
123
124    g = MULT16_32_Q15(celt_sqrt(Ryy), celt_rcp(SHR32(Ryy,9)));
125
126    i=0;
127    do 
128       X[i] = ROUND16(MULT16_16(y[i], g),11);
129    while (++i < N);
130
131    RESTORE_STACK;
132 }
133
134
135 void alg_quant(celt_norm_t *X, int N, int K, int spread, ec_enc *enc)
136 {
137    VARDECL(celt_norm_t, y);
138    VARDECL(int, iy);
139    VARDECL(celt_word16_t, signx);
140    int j, is;
141    celt_word16_t s;
142    int pulsesLeft;
143    celt_word32_t sum;
144    celt_word32_t xy, yy;
145    int N_1; /* Inverse of N, in Q14 format (even for float) */
146 #ifdef FIXED_POINT
147    int yshift;
148 #endif
149    SAVE_STACK;
150
151    K = get_pulses(K);
152 #ifdef FIXED_POINT
153    yshift = 13-celt_ilog2(K);
154 #endif
155
156    ALLOC(y, N, celt_norm_t);
157    ALLOC(iy, N, int);
158    ALLOC(signx, N, celt_word16_t);
159    N_1 = 512/N;
160    
161    if (spread)
162       exp_rotation(X, N, 1, spread, K);
163
164    sum = 0;
165    j=0; do {
166       if (X[j]>0)
167          signx[j]=1;
168       else {
169          signx[j]=-1;
170          X[j]=-X[j];
171       }
172       iy[j] = 0;
173       y[j] = 0;
174    } while (++j<N);
175
176    xy = yy = 0;
177
178    pulsesLeft = K;
179
180    /* Do a pre-search by projecting on the pyramid */
181    if (K > (N>>1))
182    {
183       celt_word16_t rcp;
184       sum=0;
185       j=0; do {
186          sum += X[j];
187       }  while (++j<N);
188
189 #ifdef FIXED_POINT
190       if (sum <= K)
191 #else
192       if (sum <= EPSILON)
193 #endif
194       {
195          X[0] = QCONST16(1.f,14);
196          j=1; do
197             X[j]=0;
198          while (++j<N);
199          sum = QCONST16(1.f,14);
200       }
201       /* Do we have sufficient accuracy here? */
202       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
203       j=0; do {
204 #ifdef FIXED_POINT
205          /* It's really important to round *towards zero* here */
206          iy[j] = MULT16_16_Q15(X[j],rcp);
207 #else
208          iy[j] = floor(rcp*X[j]);
209 #endif
210          y[j] = SHL16(iy[j],yshift);
211          yy = MAC16_16(yy, y[j],y[j]);
212          xy = MAC16_16(xy, X[j],y[j]);
213          y[j] *= 2;
214          pulsesLeft -= iy[j];
215       }  while (++j<N);
216    }
217    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
218
219    while (pulsesLeft > 0)
220    {
221       int pulsesAtOnce=1;
222       int best_id;
223       celt_word16_t magnitude;
224       celt_word32_t best_num = -VERY_LARGE16;
225       celt_word16_t best_den = 0;
226 #ifdef FIXED_POINT
227       int rshift;
228 #endif
229       /* Decide on how many pulses to find at once */
230       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
231       if (pulsesAtOnce<1)
232          pulsesAtOnce = 1;
233 #ifdef FIXED_POINT
234       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
235 #endif
236       magnitude = SHL16(pulsesAtOnce, yshift);
237
238       best_id = 0;
239       /* The squared magnitude term gets added anyway, so we might as well 
240          add it outside the loop */
241       yy = MAC16_16(yy, magnitude,magnitude);
242       /* Choose between fast and accurate strategy depending on where we are in the search */
243          /* This should ensure that anything we can process will have a better score */
244       j=0;
245       do {
246          celt_word16_t Rxy, Ryy;
247          /* Select sign based on X[j] alone */
248          s = magnitude;
249          /* Temporary sums of the new pulse(s) */
250          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
251          /* We're multiplying y[j] by two so we don't have to do it here */
252          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
253             
254             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
255          Rxy is positive because the sign is pre-computed) */
256          Rxy = MULT16_16_Q15(Rxy,Rxy);
257             /* The idea is to check for num/den >= best_num/best_den, but that way
258          we can do it without any division */
259          /* OPT: Make sure to use conditional moves here */
260          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
261          {
262             best_den = Ryy;
263             best_num = Rxy;
264             best_id = j;
265          }
266       } while (++j<N);
267       
268       j = best_id;
269       is = pulsesAtOnce;
270       s = SHL16(is, yshift);
271
272       /* Updating the sums of the new pulse(s) */
273       xy = xy + MULT16_16(s,X[j]);
274       /* We're multiplying y[j] by two so we don't have to do it here */
275       yy = yy + MULT16_16(s,y[j]);
276
277       /* Only now that we've made the final choice, update y/iy */
278       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
279       y[j] += 2*s;
280       iy[j] += is;
281       pulsesLeft -= pulsesAtOnce;
282    }
283    j=0;
284    do {
285       X[j] = MULT16_16(signx[j],X[j]);
286       if (signx[j] < 0)
287          iy[j] = -iy[j];
288    } while (++j<N);
289    encode_pulses(iy, N, K, enc);
290    
291    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
292    due to the recursive computation used in quantisation. */
293    mix_pitch_and_residual(iy, X, N, K);
294    if (spread)
295       exp_rotation(X, N, -1, spread, K);
296    RESTORE_STACK;
297 }
298
299
300 /** Decode pulse vector and combine the result with the pitch vector to produce
301     the final normalised signal in the current band. */
302 void alg_unquant(celt_norm_t *X, int N, int K, int spread, ec_dec *dec)
303 {
304    VARDECL(int, iy);
305    SAVE_STACK;
306    K = get_pulses(K);
307    ALLOC(iy, N, int);
308    decode_pulses(iy, N, K, dec);
309    mix_pitch_and_residual(iy, X, N, K);
310    if (spread)
311       exp_rotation(X, N, -1, spread, K);
312    RESTORE_STACK;
313 }
314
315 celt_word16_t renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
316 {
317    int i;
318    celt_word32_t E = EPSILON;
319    celt_word16_t rE;
320    celt_word16_t g;
321    celt_norm_t *xptr = X;
322    for (i=0;i<N;i++)
323    {
324       E = MAC16_16(E, *xptr, *xptr);
325       xptr += stride;
326    }
327
328    rE = celt_sqrt(E);
329 #ifdef FIXED_POINT
330    if (rE <= 128)
331       g = Q15ONE;
332    else
333 #endif
334       g = MULT16_16_Q15(value,celt_rcp(SHL32(rE,9)));
335    xptr = X;
336    for (i=0;i<N;i++)
337    {
338       *xptr = PSHR32(MULT16_16(g, *xptr),8);
339       xptr += stride;
340    }
341    return rE;
342 }
343
344 static void fold(const CELTMode *m, int N, const celt_norm_t * restrict Y, celt_norm_t * restrict P, int N0, int B)
345 {
346    int j;
347    int id = N0 % B;
348    /* Here, we assume that id will never be greater than N0, i.e. that 
349       no band is wider than N0. In the unlikely case it happens, we set
350       everything to zero */
351    /*{
352            int offset = (N0*C - (id+C*N))/2;
353            if (offset > C*N0/16)
354                    offset = C*N0/16;
355            offset -= offset % (C*B);
356            if (offset < 0)
357                    offset = 0;
358            //printf ("%d\n", offset);
359            id += offset;
360    }*/
361    if (id+N>N0)
362       for (j=0;j<N;j++)
363          P[j] = 0;
364    else
365       for (j=0;j<N;j++)
366          P[j] = Y[id++];
367 }
368
369 void intra_fold(const CELTMode *m, int N, const celt_norm_t * restrict Y, celt_norm_t * restrict P, int N0, int B)
370 {
371    fold(m, N, Y, P, N0, B);
372    renormalise_vector(P, Q15ONE, N, 1);
373 }
374