Updated copyright notices
[opus.git] / libcelt / vq.c
1 /* Copyright (c) 2007-2008 CSIRO
2    Copyright (c) 2007-2009 Xiph.Org Foundation
3    Written by Jean-Marc Valin */
4 /*
5    Redistribution and use in source and binary forms, with or without
6    modification, are permitted provided that the following conditions
7    are met:
8    
9    - Redistributions of source code must retain the above copyright
10    notice, this list of conditions and the following disclaimer.
11    
12    - Redistributions in binary form must reproduce the above copyright
13    notice, this list of conditions and the following disclaimer in the
14    documentation and/or other materials provided with the distribution.
15    
16    - Neither the name of the Xiph.org Foundation nor the names of its
17    contributors may be used to endorse or promote products derived from
18    this software without specific prior written permission.
19    
20    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
24    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
25    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
26    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 */
32
33 #ifdef HAVE_CONFIG_H
34 #include "config.h"
35 #endif
36
37 #include "mathops.h"
38 #include "cwrs.h"
39 #include "vq.h"
40 #include "arch.h"
41 #include "os_support.h"
42 #include "rate.h"
43
44 #ifndef M_PI
45 #define M_PI 3.141592653
46 #endif
47
48 static void exp_rotation(celt_norm_t *X, int len, int dir, int stride, int K)
49 {
50    int i, k, iter;
51    celt_word16_t c, s;
52    celt_word16_t gain, theta;
53    celt_norm_t *Xptr;
54    gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,len),(celt_word32_t)(3+len+6*K));
55    /* FIXME: Make that HALF16 instead of HALF32 */
56    theta = SUB16(Q15ONE, HALF32(MULT16_16_Q15(gain,gain)));
57    /*if (len==30)
58    {
59    for (i=0;i<len;i++)
60    X[i] = 0;
61    X[14] = 1;
62 }*/ 
63    c = celt_cos_norm(EXTEND32(theta));
64    s = dir*celt_cos_norm(EXTEND32(SUB16(Q15ONE,theta))); /*  sin(theta) */
65    if (len > 8*stride)
66       stride *= len/(8*stride);
67    iter = 1;
68    for (k=0;k<iter;k++)
69    {
70       /* We could use MULT16_16_P15 instead of MULT16_16_Q15 for more accuracy, 
71       but at this point, I really don't think it's necessary */
72       Xptr = X;
73       for (i=0;i<len-stride;i++)
74       {
75          celt_norm_t x1, x2;
76          x1 = Xptr[0];
77          x2 = Xptr[stride];
78          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
79          *Xptr++      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
80       }
81       Xptr = &X[len-2*stride-1];
82       for (i=len-2*stride-1;i>=0;i--)
83       {
84          celt_norm_t x1, x2;
85          x1 = Xptr[0];
86          x2 = Xptr[stride];
87          Xptr[stride] = MULT16_16_Q15(c,x2) + MULT16_16_Q15(s,x1);
88          *Xptr--      = MULT16_16_Q15(c,x1) - MULT16_16_Q15(s,x2);
89       }
90    }
91    /*if (len==30)
92    {
93    for (i=0;i<len;i++)
94    printf ("%f ", X[i]);
95    printf ("\n");
96    exit(0);
97 }*/
98 }
99
100
101 /** Takes the pitch vector and the decoded residual vector, computes the gain
102     that will give ||p+g*y||=1 and mixes the residual with the pitch. */
103 static void normalise_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, celt_word32_t Ryy)
104 {
105    int i;
106    celt_word32_t g;
107
108    g = celt_rsqrt(Ryy);
109
110    i=0;
111    do
112       X[i] = SHR16(MULT16_16_16(g, iy[i]),1);
113    while (++i < N);
114 }
115
116 void alg_quant(celt_norm_t *X, int N, int K, int spread, ec_enc *enc)
117 {
118    VARDECL(celt_norm_t, y);
119    VARDECL(int, iy);
120    VARDECL(celt_word16_t, signx);
121    int j, is;
122    celt_word16_t s;
123    int pulsesLeft;
124    celt_word32_t sum;
125    celt_word32_t xy, yy;
126    int N_1; /* Inverse of N, in Q14 format (even for float) */
127 #ifdef FIXED_POINT
128    int yshift;
129 #endif
130    SAVE_STACK;
131
132    K = get_pulses(K);
133 #ifdef FIXED_POINT
134    yshift = 13-celt_ilog2(K);
135 #endif
136
137    ALLOC(y, N, celt_norm_t);
138    ALLOC(iy, N, int);
139    ALLOC(signx, N, celt_word16_t);
140    N_1 = 512/N;
141    
142    if (spread)
143       exp_rotation(X, N, 1, spread, K);
144
145    sum = 0;
146    j=0; do {
147       if (X[j]>0)
148          signx[j]=1;
149       else {
150          signx[j]=-1;
151          X[j]=-X[j];
152       }
153       iy[j] = 0;
154       y[j] = 0;
155    } while (++j<N);
156
157    xy = yy = 0;
158
159    pulsesLeft = K;
160
161    /* Do a pre-search by projecting on the pyramid */
162    if (K > (N>>1))
163    {
164       celt_word16_t rcp;
165       sum=0;
166       j=0; do {
167          sum += X[j];
168       }  while (++j<N);
169
170 #ifdef FIXED_POINT
171       if (sum <= K)
172 #else
173       if (sum <= EPSILON)
174 #endif
175       {
176          X[0] = QCONST16(1.f,14);
177          j=1; do
178             X[j]=0;
179          while (++j<N);
180          sum = QCONST16(1.f,14);
181       }
182       /* Do we have sufficient accuracy here? */
183       rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
184       j=0; do {
185 #ifdef FIXED_POINT
186          /* It's really important to round *towards zero* here */
187          iy[j] = MULT16_16_Q15(X[j],rcp);
188 #else
189          iy[j] = floor(rcp*X[j]);
190 #endif
191          y[j] = SHL16(iy[j],yshift);
192          yy = MAC16_16(yy, y[j],y[j]);
193          xy = MAC16_16(xy, X[j],y[j]);
194          y[j] *= 2;
195          pulsesLeft -= iy[j];
196       }  while (++j<N);
197    }
198    celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");
199
200    while (pulsesLeft > 0)
201    {
202       int pulsesAtOnce=1;
203       int best_id;
204       celt_word16_t magnitude;
205       celt_word32_t best_num = -VERY_LARGE16;
206       celt_word16_t best_den = 0;
207 #ifdef FIXED_POINT
208       int rshift;
209 #endif
210       /* Decide on how many pulses to find at once */
211       pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
212       if (pulsesAtOnce<1)
213          pulsesAtOnce = 1;
214 #ifdef FIXED_POINT
215       rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
216 #endif
217       magnitude = SHL16(pulsesAtOnce, yshift);
218
219       best_id = 0;
220       /* The squared magnitude term gets added anyway, so we might as well 
221          add it outside the loop */
222       yy = MAC16_16(yy, magnitude,magnitude);
223       /* Choose between fast and accurate strategy depending on where we are in the search */
224          /* This should ensure that anything we can process will have a better score */
225       j=0;
226       do {
227          celt_word16_t Rxy, Ryy;
228          /* Select sign based on X[j] alone */
229          s = magnitude;
230          /* Temporary sums of the new pulse(s) */
231          Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
232          /* We're multiplying y[j] by two so we don't have to do it here */
233          Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));
234             
235             /* Approximate score: we maximise Rxy/sqrt(Ryy) (we're guaranteed that 
236          Rxy is positive because the sign is pre-computed) */
237          Rxy = MULT16_16_Q15(Rxy,Rxy);
238             /* The idea is to check for num/den >= best_num/best_den, but that way
239          we can do it without any division */
240          /* OPT: Make sure to use conditional moves here */
241          if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
242          {
243             best_den = Ryy;
244             best_num = Rxy;
245             best_id = j;
246          }
247       } while (++j<N);
248       
249       j = best_id;
250       is = pulsesAtOnce;
251       s = SHL16(is, yshift);
252
253       /* Updating the sums of the new pulse(s) */
254       xy = xy + MULT16_16(s,X[j]);
255       /* We're multiplying y[j] by two so we don't have to do it here */
256       yy = yy + MULT16_16(s,y[j]);
257
258       /* Only now that we've made the final choice, update y/iy */
259       /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
260       y[j] += 2*s;
261       iy[j] += is;
262       pulsesLeft -= pulsesAtOnce;
263    }
264    j=0;
265    do {
266       X[j] = MULT16_16(signx[j],X[j]);
267       if (signx[j] < 0)
268          iy[j] = -iy[j];
269    } while (++j<N);
270    encode_pulses(iy, N, K, enc);
271    
272    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
273    due to the recursive computation used in quantisation. */
274    normalise_residual(iy, X, N, K, EXTRACT16(SHR32(yy,2*yshift)));
275    if (spread)
276       exp_rotation(X, N, -1, spread, K);
277    RESTORE_STACK;
278 }
279
280
281 /** Decode pulse vector and combine the result with the pitch vector to produce
282     the final normalised signal in the current band. */
283 void alg_unquant(celt_norm_t *X, int N, int K, int spread, ec_dec *dec)
284 {
285    int i;
286    celt_word32_t Ryy;
287    VARDECL(int, iy);
288    SAVE_STACK;
289    K = get_pulses(K);
290    ALLOC(iy, N, int);
291    decode_pulses(iy, N, K, dec);
292    Ryy = 0;
293    i=0;
294    do {
295       Ryy = MAC16_16(Ryy, iy[i], iy[i]);
296    } while (++i < N);
297    normalise_residual(iy, X, N, K, Ryy);
298    if (spread)
299       exp_rotation(X, N, -1, spread, K);
300    RESTORE_STACK;
301 }
302
303 celt_word16_t renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
304 {
305    int i;
306    celt_word32_t E = EPSILON;
307    celt_word16_t rE;
308    celt_word16_t g;
309    celt_norm_t *xptr = X;
310    for (i=0;i<N;i++)
311    {
312       E = MAC16_16(E, *xptr, *xptr);
313       xptr += stride;
314    }
315
316    rE = celt_sqrt(E);
317 #ifdef FIXED_POINT
318    if (rE <= 128)
319       g = Q15ONE;
320    else
321 #endif
322       g = MULT16_16_Q15(value,celt_rcp(SHL32(rE,9)));
323    xptr = X;
324    for (i=0;i<N;i++)
325    {
326       *xptr = PSHR32(MULT16_16(g, *xptr),8);
327       xptr += stride;
328    }
329    return rE;
330 }
331
332 static void fold(const CELTMode *m, int N, const celt_norm_t * restrict Y, celt_norm_t * restrict P, int N0, int B)
333 {
334    int j;
335    int id = N0 % B;
336    /* Here, we assume that id will never be greater than N0, i.e. that 
337       no band is wider than N0. In the unlikely case it happens, we set
338       everything to zero */
339    /*{
340            int offset = (N0*C - (id+C*N))/2;
341            if (offset > C*N0/16)
342                    offset = C*N0/16;
343            offset -= offset % (C*B);
344            if (offset < 0)
345                    offset = 0;
346            //printf ("%d\n", offset);
347            id += offset;
348    }*/
349    if (id+N>N0)
350       for (j=0;j<N;j++)
351          P[j] = 0;
352    else
353       for (j=0;j<N;j++)
354          P[j] = Y[id++];
355 }
356
357 void intra_fold(const CELTMode *m, int N, const celt_norm_t * restrict Y, celt_norm_t * restrict P, int N0, int B)
358 {
359    fold(m, N, Y, P, N0, B);
360    renormalise_vector(P, Q15ONE, N, 1);
361 }
362