Corrected some non-sensical code
[opus.git] / libcelt / vq.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "mathops.h"
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42 #include "rate.h"
43
44 #ifndef M_PI
45 #define M_PI 3.141592653
46 #endif
47
48 static void frac_hadamard1(celt_norm *X, int len, int stride, celt_word16 c, celt_word16 s)
49 {
50    int j;
51    celt_norm *x, *y;
52    celt_norm * end;
53
54    j = 0;
55    x = X;
56    y = X+stride;
57    end = X+len;
58    do
59    {
60       celt_norm x1, x2;
61       x1 = *x;
62       x2 = *y;
63       *x++ = EXTRACT16(SHR32(MULT16_16(c,x1) + MULT16_16(s,x2),15));
64       *y++ = EXTRACT16(SHR32(MULT16_16(s,x1) - MULT16_16(c,x2),15));
65       j++;
66       if (j>=stride)
67       {
68          j=0;
69          x+=stride;
70          y+=stride;
71       }
72    } while (y<end);
73
74    /* Reverse samples so that the next level starts from the other end */
75    for (j=0;j<len>>1;j++)
76    {
77       celt_norm tmp = X[j];
78       X[j] = X[len-j-1];
79       X[len-j-1] = tmp;
80    }
81 }
82
83 #define MAX_LEVELS 8
84 static void exp_rotation(celt_norm *X, int len, int dir, int stride, int K)
85 {
86    int i, N=0;
87    int transient;
88    celt_word16 gain, theta;
89    int istride[MAX_LEVELS];
90    celt_word16 c[MAX_LEVELS], s[MAX_LEVELS];
91
92    if (K >= len)
93       return;
94    transient = stride>1;
95    /*if (len>=30)
96    {
97       for (i=0;i<len;i++)
98          X[i] = 0;
99       X[30] = 1;
100       dir = -1;
101       transient = 1;
102    }*/
103    gain = celt_div((celt_word32)MULT16_16(Q15_ONE,len),(celt_word32)(3+len+4*K));
104    /* FIXME: Make that HALF16 instead of HALF32 */
105    theta = MIN16(QCONST16(.25f,15),HALF32(MULT16_16_Q15(gain,gain)));
106    c[0] = celt_cos_norm(EXTEND32(theta));
107    s[0] = celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
108
109    do {
110       istride[N] = stride;
111       stride *= 2;
112       c[N] = c[0];
113       s[N] = s[0];
114       N++;
115    } while (N<MAX_LEVELS && stride < len);
116
117    /* This should help a little bit with the transients */
118    if (transient)
119       c[0] = s[0] = QCONST16(.7071068f, 15);
120
121    /* Needs to be < 0 to prevent gaps on the side of the spreading */
122    if (dir < 0)
123    {
124       for (i=0;i<N;i++)
125          frac_hadamard1(X, len, istride[i], c[i], s[i]);
126    } else {
127       for (i=N-1;i>=0;i--)
128          frac_hadamard1(X, len, istride[i], c[i], s[i]);
129    }
130
131    /* Undo last reversal */
132    for (i=0;i<len>>1;i++)
133    {
134       celt_norm tmp = X[i];
135       X[i] = X[len-i-1];
136       X[len-i-1] = tmp;
137    }
138    /*if (len>=30)
139    {
140       for (i=0;i<len;i++)
141          printf ("%f ", X[i]);
142       printf ("\n");
143       exit(0);
144    }*/
145 }
146
147 /** Takes the pitch vector and the decoded residual vector, computes the gain
148     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
149 static void normalise_residual(int * restrict iy, celt_norm * restrict X, int N, int K, celt_word32 Ryy)
150 {
151    int i;
152 #ifdef FIXED_POINT
153    int k;
154 #endif
155    celt_word32 t;
156    celt_word16 g;
157
158 #ifdef FIXED_POINT
159    k = celt_ilog2(Ryy)>>1;
160 #endif
161    t = VSHR32(Ryy, (k-7)<<1);
162    g = celt_rsqrt_norm(t);
163
164    i=0;
165    do
166       X[i] = EXTRACT16(PSHR32(MULT16_16(g, iy[i]), k+1));
167    while (++i < N);
168 }
169
170 void alg_quant(celt_norm *X, int N, int K, int spread, ec_enc *enc)
171 {
172    VARDECL(celt_norm, y);
173    VARDECL(int, iy);
174    VARDECL(celt_word16, signx);
175    int j, is;
176    celt_word16 s;
177    int pulsesLeft;
178    celt_word32 sum;
179    celt_word32 xy, yy;
180    int N_1; /* Inverse of N, in Q14 format (even for float) */
181 #ifdef FIXED_POINT
182    int yshift;
183 #endif
184    SAVE_STACK;
185
186    K = get_pulses(K);
187 #ifdef FIXED_POINT
188    yshift = 13-celt_ilog2(K);
189 #endif
190
191    ALLOC(y, N, celt_norm);
192    ALLOC(iy, N, int);
193    ALLOC(signx, N, celt_word16);
194    N_1 = 512/N;
195    
196    if (spread)
197       exp_rotation(X, N, 1, spread, K);
198
199    sum = 0;
200    j=0; do {
201       if (X[j]>0)
202          signx[j]=1;
203       else {
204          signx[j]=-1;
205          X[j]=-X[j];
206       }
207       iy[j] = 0;
208       y[j] = 0;
209    } while (++j<N);
210
211    xy = yy = 0;
212
213    pulsesLeft = K;
214
215    /* Do a pre-search by projecting on the pyramid */
216    if (K > (N>>1))
217    {
218       celt_word16 rcp;
219       j=0; do {
220          sum += X[j];
221       }  while (++j<N);
222
223 #ifdef FIXED_POINT
224       if (sum <= K)
225 #else
226       if (sum <= EPSILON)
227 #endif
228       {
229          X[0] = QCONST16(1.f,14);
230          j=1; do
231             X[j]=0;
232          while (++j<N);
233          sum = QCONST16(1.f,14);
234       }
235       /* Do we have sufficient accuracy here? */
236       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
237       j=0; do {
238 #ifdef FIXED_POINT
239          /* It's really important to round *towards zero* here */
240          iy[j] = MULT16_16_Q15(X[j],rcp);
241 #else
242          iy[j] = floor(rcp*X[j]);
243 #endif
244          y[j] = SHL16(iy[j],yshift);
245          yy = MAC16_16(yy, y[j],y[j]);
246          xy = MAC16_16(xy, X[j],y[j]);
247          y[j] *= 2;
248          pulsesLeft -= iy[j];
249       }  while (++j<N);
250    }
251    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
252
253    while (pulsesLeft > 0)
254    {
255       int pulsesAtOnce=1;
256       int best_id;
257       celt_word16 magnitude;
258       celt_word32 best_num = -VERY_LARGE16;
259       celt_word16 best_den = 0;
260 #ifdef FIXED_POINT
261       int rshift;
262 #endif
263       /* Decide on how many pulses to find at once */
264       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
265       if (pulsesAtOnce<1)
266          pulsesAtOnce = 1;
267 #ifdef FIXED_POINT
268       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
269 #endif
270       magnitude = SHL16(pulsesAtOnce, yshift);
271
272       best_id = 0;
273       /* The squared magnitude term gets added anyway, so we might as well 
274          add it outside the loop */
275       yy = MAC16_16(yy, magnitude,magnitude);
276       /* Choose between fast and accurate strategy depending on where we are in the search */
277          /* This should ensure that anything we can process will have a better score */
278       j=0;
279       do {
280          celt_word16 Rxy, Ryy;
281          /* Select sign based on X[j] alone */
282          s = magnitude;
283          /* Temporary sums of the new pulse(s) */
284          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
285          /* We're multiplying y[j] by two so we don't have to do it here */
286          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
287             
288             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
289          Rxy is positive because the sign is pre-computed) */
290          Rxy = MULT16_16_Q15(Rxy,Rxy);
291             /* The idea is to check for num/den >= best_num/best_den, but that way
292          we can do it without any division */
293          /* OPT: Make sure to use conditional moves here */
294          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
295          {
296             best_den = Ryy;
297             best_num = Rxy;
298             best_id = j;
299          }
300       } while (++j<N);
301       
302       j = best_id;
303       is = pulsesAtOnce;
304       s = SHL16(is, yshift);
305
306       /* Updating the sums of the new pulse(s) */
307       xy = xy + MULT16_16(s,X[j]);
308       /* We're multiplying y[j] by two so we don't have to do it here */
309       yy = yy + MULT16_16(s,y[j]);
310
311       /* Only now that we've made the final choice, update y/iy */
312       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
313       y[j] += 2*s;
314       iy[j] += is;
315       pulsesLeft -= pulsesAtOnce;
316    }
317    j=0;
318    do {
319       X[j] = MULT16_16(signx[j],X[j]);
320       if (signx[j] < 0)
321          iy[j] = -iy[j];
322    } while (++j<N);
323    encode_pulses(iy, N, K, enc);
324    
325    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
326    due to the recursive computation used in quantisation. */
327    normalise_residual(iy, X, N, K, EXTRACT16(SHR32(yy,2*yshift)));
328    if (spread)
329       exp_rotation(X, N, -1, spread, K);
330    RESTORE_STACK;
331 }
332
333
334 /** Decode pulse vector and combine the result with the pitch vector to produce
335     the final normalised signal in the current band. */
336 void alg_unquant(celt_norm *X, int N, int K, int spread, ec_dec *dec)
337 {
338    int i;
339    celt_word32 Ryy;
340    VARDECL(int, iy);
341    SAVE_STACK;
342    K = get_pulses(K);
343    ALLOC(iy, N, int);
344    decode_pulses(iy, N, K, dec);
345    Ryy = 0;
346    i=0;
347    do {
348       Ryy = MAC16_16(Ryy, iy[i], iy[i]);
349    } while (++i < N);
350    normalise_residual(iy, X, N, K, Ryy);
351    if (spread)
352       exp_rotation(X, N, -1, spread, K);
353    RESTORE_STACK;
354 }
355
356 celt_word16 renormalise_vector(celt_norm *X, celt_word16 value, int N, int stride)
357 {
358    int i;
359    celt_word32 E = EPSILON;
360    celt_word16 g;
361    celt_word32 t;
362    celt_norm *xptr = X;
363    for (i=0;i<N;i++)
364    {
365       E = MAC16_16(E, *xptr, *xptr);
366       xptr += stride;
367    }
368 #ifdef FIXED_POINT
369    int k = celt_ilog2(E)>>1;
370 #endif
371    t = VSHR32(E, (k-7)<<1);
372    g = MULT16_16_Q15(value, celt_rsqrt_norm(t));
373
374    xptr = X;
375    for (i=0;i<N;i++)
376    {
377       *xptr = EXTRACT16(PSHR32(MULT16_16(g, *xptr), k+1));
378       xptr += stride;
379    }
380    return celt_sqrt(E);
381 }
382
383 static void fold(const CELTMode *m, int start, int N, const celt_norm * restrict Y, celt_norm * restrict P, int N0, int B)
384 {
385    int j;
386    int id = N0 % B;
387    while (id < m->eBands[start])
388       id += B;
389    /* Here, we assume that id will never be greater than N0, i.e. that 
390       no band is wider than N0. In the unlikely case it happens, we set
391       everything to zero */
392    /*{
393            int offset = (N0*C - (id+C*N))/2;
394            if (offset > C*N0/16)
395                    offset = C*N0/16;
396            offset -= offset % (C*B);
397            if (offset < 0)
398                    offset = 0;
399            //printf ("%d\n", offset);
400            id += offset;
401    }*/
402    if (id+N>N0)
403       for (j=0;j<N;j++)
404          P[j] = 0;
405    else
406       for (j=0;j<N;j++)
407          P[j] = Y[id++];
408 }
409
410 void intra_fold(const CELTMode *m, int start, int N, const celt_norm * restrict Y, celt_norm * restrict P, int N0, int B)
411 {
412    fold(m, start, N, Y, P, N0, B);
413    renormalise_vector(P, Q15ONE, N, 1);
414 }
415